"""Hugging Face Space UI for the CDN Cache Optimizer."""
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np

from env.cache import CDNCacheEnv, TASK_CONFIGS
from env.models import Action, Observation

@dataclass
class EpisodeMetrics:
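    """Metrics accumulated over a single episode rollout."""
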
    rewards: List[float]
    hit_rates: List[float]
    final_hit_rate: float
    total_reward: float
    bandwidth_saved_mb: float

def lru_baseline(obs: Observation) -> Action:
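    """LRU baseline: on a cache miss, evict the least-recently-accessed file."""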
    if obs.cache_hit or not obs.cached_files:
        return Action(evict_file_id=None)
    victim = min(obs.cached_files, key=lambda f: f.last_accessed)
    return Action(evict_file_id=victim.file_id)

def smart_agent(obs: Observation) -> Action:
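    """Heuristic agent: evict only under pressure, protecting previewed and viral files."""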
    if obs.cache_hit or not obs.cached_files:
        return Action(evict_file_id=None)
    # Below ~92% fill there is no cache pressure, so admit without evicting.
    if obs.cache_fill_ratio < 0.92:
        return Action(evict_file_id=None)
    preview = set(obs.queue_preview)

    def score(file_entry) -> Tuple[int, int, int, float]:
        # min() takes the lowest tuple, so the victim is the file that is
        # not in the request preview, not viral, rarely requested, and
        # largest (size is negated so bigger files sort lower).
        preview_keep = 1 if file_entry.file_id in preview else 0
        viral_keep = 1 if file_entry.is_viral else 0
        return (
            preview_keep,
            viral_keep,
            file_entry.request_frequency,
            -file_entry.size_mb,
        )

    victim = min(obs.cached_files, key=score)
    return Action(evict_file_id=victim.file_id)

def run_episode(task_id: str, seed: int, policy: Callable[[Observation], Action]) -> EpisodeMetrics:
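    """Roll ``policy`` through one full episode and collect per-step metrics.

    Minimal usage sketch (assumes a task id such as "steady_traffic" exists
    in TASK_CONFIGS; substitute any key from that mapping):

        metrics = run_episode("steady_traffic", seed=7, policy=lru_baseline)
        print(f"hit rate: {metrics.final_hit_rate:.1%}")
    """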
    env = CDNCacheEnv(task_id=task_id, seed=seed)
    obs = env.reset()
    rewards: List[float] = []
    hit_rates: List[float] = []
    done = False
    info: Dict = {}
    while not done:
        result = env.step(policy(obs))
        obs = result.observation
        info = result.info
        rewards.append(result.reward.total)
        hit_rates.append(float(info["hit_rate"]))
        done = result.done
    return EpisodeMetrics(
        rewards=rewards,
        hit_rates=hit_rates,
        final_hit_rate=float(info.get("hit_rate", 0.0)),
        total_reward=float(sum(rewards)),
        bandwidth_saved_mb=float(info.get("bandwidth_saved_mb", 0.0)),
    )

def make_plot(baseline: EpisodeMetrics, agent: EpisodeMetrics):
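    """Plot per-step hit-rate curves and a grouped bar chart comparing both policies."""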
    fig, axes = plt.subplots(1, 2, figsize=(12, 4.6), dpi=150)
    fig.patch.set_facecolor("#0b1220")
    # Dark theme styling shared by both panels.
    for ax in axes:
        ax.set_facecolor("#111827")
        ax.grid(True, alpha=0.25)
        ax.tick_params(colors="#d1d5db")
        ax.xaxis.label.set_color("#d1d5db")
        ax.yaxis.label.set_color("#d1d5db")
        ax.title.set_color("#f9fafb")

    # Left panel: per-step hit-rate curves for both policies.
    x = np.arange(1, len(agent.hit_rates) + 1)
    axes[0].plot(x, baseline.hit_rates, color="#fb923c", lw=2, label="Baseline LRU")
    axes[0].plot(x, agent.hit_rates, color="#22c55e", lw=2, label="Fine-tuned Agent")
    axes[0].set_title("Cache Hit Rate Over Episode")
    axes[0].set_xlabel("Step")
    axes[0].set_ylabel("Hit rate")
    axes[0].legend(facecolor="#1f2937", labelcolor="#f9fafb")

    # Right panel: grouped bars comparing final episode totals.
    labels = ["Reward", "Hit Rate", "Bandwidth Saved"]
    baseline_values = [baseline.total_reward, baseline.final_hit_rate * 100, baseline.bandwidth_saved_mb]
    agent_values = [agent.total_reward, agent.final_hit_rate * 100, agent.bandwidth_saved_mb]
    idx = np.arange(len(labels))
    width = 0.36
    axes[1].bar(idx - width / 2, baseline_values, width, label="Baseline", color="#fb923c")
    axes[1].bar(idx + width / 2, agent_values, width, label="Agent", color="#22c55e")
    axes[1].set_xticks(idx)
    axes[1].set_xticklabels(labels, rotation=8, ha="right", color="#d1d5db")
    axes[1].set_title("Final Comparison")
    axes[1].legend(facecolor="#1f2937", labelcolor="#f9fafb")

    fig.suptitle("CDN Cache Optimizer: OpenEnv Agent Benchmark", color="#f9fafb", fontweight="bold")
    fig.tight_layout()
    return fig

def run_demo(task_label: str, seed: int):
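    """Benchmark the LRU baseline against the smart agent and build the report."""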
    # Dropdown labels look like "<task_id> - <name>"; recover the id.
    task_id = task_label.split(" ")[0]
    baseline = run_episode(task_id, int(seed), lru_baseline)
    agent = run_episode(task_id, int(seed), smart_agent)
    uplift = agent.final_hit_rate - baseline.final_hit_rate
    reward_uplift = agent.total_reward - baseline.total_reward
    summary = (
        f"### Results for `{task_id}`\n"
        f"- Baseline LRU reward: **{baseline.total_reward:.2f}**, hit rate: **{baseline.final_hit_rate:.1%}**\n"
        f"- Fine-tuned agent reward: **{agent.total_reward:.2f}**, hit rate: **{agent.final_hit_rate:.1%}**\n"
        f"- Reward uplift: **{reward_uplift:+.2f}** | Hit-rate uplift: **{uplift:+.1%}**\n\n"
        "The agent keeps viral/previewed objects, evicts low-frequency cold content, "
        "and avoids unnecessary churn under cache pressure."
    )
    return summary, make_plot(baseline, agent)

task_choices = [
    f"{task_id} - {cfg.name}" for task_id, cfg in TASK_CONFIGS.items()
]

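# Assemble the Gradio UI: task selector, seed input, and benchmark outputs.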
with gr.Blocks(title="CDN Cache Optimizer") as demo:
    gr.Markdown(
        """
        # CDN Cache Optimizer

        OpenEnv-compliant reinforcement-learning environment for edge CDN cache
        admission and eviction. The live demo compares an LRU baseline with a
        fine-tuned agent policy on realistic steady and viral traffic.
        """
    )
    with gr.Row():
        task = gr.Dropdown(task_choices, value=task_choices[-1], label="OpenEnv task")
        seed = gr.Number(value=42, precision=0, label="Seed")
    run_btn = gr.Button("Run Benchmark", variant="primary")
    output = gr.Markdown()
    plot = gr.Plot()
    run_btn.click(run_demo, inputs=[task, seed], outputs=[output, plot])
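    # Render default results as soon as the page loads.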
    demo.load(run_demo, inputs=[task, seed], outputs=[output, plot])

if __name__ == "__main__":
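    # Hugging Face Spaces inject GRADIO_SERVER_NAME/GRADIO_SERVER_PORT; default to local dev values.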
    import os

    host = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    demo.launch(server_name=host, server_port=port)