umar-sharif821's picture
feat: reproducible eval + shared policy module; fix smart_agent wasted-capacity bug
ddf831c
"""Hugging Face Space UI for the CDN Cache Optimizer."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from agents.policies import lru_baseline, smart_agent
from env.cache import CDNCacheEnv, TASK_CONFIGS
from env.models import Action, Observation
@dataclass()
class EpisodeMetrics:
    """Summary of a single environment episode, as produced by ``run_episode``.

    Keeps the per-step traces (for the hit-rate line chart) alongside the
    end-of-episode aggregates (for the summary text and bar chart).
    """

    # ``reward.total`` of every step, in step order.
    rewards: List[float]
    # ``info["hit_rate"]`` reported after every step, in step order.
    hit_rates: List[float]
    # ``hit_rate`` from the info dict of the final step.
    final_hit_rate: float
    # Sum of ``rewards``.
    total_reward: float
    # ``bandwidth_saved_mb`` from the info dict of the final step.
    bandwidth_saved_mb: float
def run_episode(task_id: str, seed: int, policy: Callable[[Observation], Action]) -> EpisodeMetrics:
    """Play one complete episode with ``policy`` and collect its metrics.

    Args:
        task_id: Task key forwarded to ``CDNCacheEnv``.
        seed: Seed forwarded to ``CDNCacheEnv``. ``run_demo`` passes the same
            seed for baseline and agent, presumably so both see identical
            traffic (depends on the env's seeding; not visible here).
        policy: Callable mapping the current observation to the next action.

    Returns:
        An ``EpisodeMetrics`` with the per-step reward / hit-rate traces and
        the final-step ``hit_rate`` and ``bandwidth_saved_mb`` (0.0 when the
        final info dict lacks those keys).
    """
    env = CDNCacheEnv(task_id=task_id, seed=seed)
    observation = env.reset()
    step_rewards: List[float] = []
    step_hit_rates: List[float] = []
    last_info: Dict = {}

    # Step until the environment signals termination; the loop body always
    # runs at least once, so ``last_info`` ends up holding the final step's info.
    while True:
        outcome = env.step(policy(observation))
        observation = outcome.observation
        last_info = outcome.info
        step_rewards.append(outcome.reward.total)
        step_hit_rates.append(float(last_info["hit_rate"]))
        if outcome.done:
            break

    return EpisodeMetrics(
        rewards=step_rewards,
        hit_rates=step_hit_rates,
        final_hit_rate=float(last_info.get("hit_rate", 0.0)),
        total_reward=float(sum(step_rewards)),
        bandwidth_saved_mb=float(last_info.get("bandwidth_saved_mb", 0.0)),
    )
def make_plot(baseline: EpisodeMetrics, agent: EpisodeMetrics):
    """Build the two-panel baseline-vs-agent comparison figure.

    Left panel: per-step hit-rate trace of each episode. Right panel: grouped
    bars for total reward, final hit rate (in percent) and bandwidth saved.

    The figure is created through the object-oriented ``Figure`` API instead of
    ``pyplot``. This function backs a long-running Gradio server (and Gradio
    may call it from worker threads). ``pyplot`` keeps every figure it creates
    in global, non-thread-safe state until ``plt.close`` is called, which the
    previous version never did, so figures accumulated on every request.

    Args:
        baseline: Metrics from the LRU baseline episode.
        agent: Metrics from the agent episode.

    Returns:
        A ``matplotlib.figure.Figure`` to hand to ``gr.Plot``.
    """
    # Local import: the module only imports pyplot, which this function avoids.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 4.6), dpi=150)
    axes = fig.subplots(1, 2)
    fig.patch.set_facecolor("#0b1220")
    for ax in axes:
        ax.set_facecolor("#111827")
        ax.grid(True, alpha=0.25)
        ax.tick_params(colors="#d1d5db")
        ax.xaxis.label.set_color("#d1d5db")
        ax.yaxis.label.set_color("#d1d5db")
        ax.title.set_color("#f9fafb")

    # Left panel. Each trace gets its own 1-based step index. Nothing here
    # guarantees the two episodes have the same length, and matplotlib raises
    # if an x array and its y series differ in length.
    baseline_steps = np.arange(1, len(baseline.hit_rates) + 1)
    agent_steps = np.arange(1, len(agent.hit_rates) + 1)
    axes[0].plot(baseline_steps, baseline.hit_rates, color="#fb923c", lw=2, label="Baseline LRU")
    axes[0].plot(agent_steps, agent.hit_rates, color="#22c55e", lw=2, label="Fine-tuned Agent")
    axes[0].set_title("Cache Hit Rate Over Episode")
    axes[0].set_xlabel("Step")
    axes[0].set_ylabel("Hit rate")
    axes[0].legend(facecolor="#1f2937", labelcolor="#f9fafb")

    # Right panel. The three metrics share one y axis, so each label carries
    # its unit (the hit rate is scaled to percent).
    labels = ["Reward", "Hit Rate (%)", "Bandwidth Saved (MB)"]
    baseline_values = [baseline.total_reward, baseline.final_hit_rate * 100, baseline.bandwidth_saved_mb]
    agent_values = [agent.total_reward, agent.final_hit_rate * 100, agent.bandwidth_saved_mb]
    idx = np.arange(len(labels))
    width = 0.36
    axes[1].bar(idx - width / 2, baseline_values, width, label="Baseline", color="#fb923c")
    axes[1].bar(idx + width / 2, agent_values, width, label="Agent", color="#22c55e")
    axes[1].set_xticks(idx)
    axes[1].set_xticklabels(labels, rotation=8, ha="right", color="#d1d5db")
    axes[1].set_title("Final Comparison")
    axes[1].legend(facecolor="#1f2937", labelcolor="#f9fafb")

    fig.suptitle("CDN Cache Optimizer: OpenEnv Agent Benchmark", color="#f9fafb", fontweight="bold")
    fig.tight_layout()
    return fig
def run_demo(task_label: str, seed: int):
    """Benchmark the LRU baseline against the agent on one task and seed.

    Both policies are rolled out with ``run_episode`` using the same task id
    and seed. The outcome is rendered as a markdown summary and a comparison
    figure.

    Args:
        task_label: Dropdown value of the form ``"<task_id> - <task name>"``.
        seed: Value of the seed ``gr.Number`` field; coerced with ``int``.

    Returns:
        ``(summary_markdown, figure)``, matching the ``[output, plot]``
        components this handler is wired to.

    Raises:
        gr.Error: If no task is selected or the seed field is empty. Gradio
            shows the message to the user instead of a generic error.
    """
    # Gradio may pass None for a cleared dropdown / number field. Without these
    # guards the user would only see an opaque AttributeError / TypeError.
    if not task_label:
        raise gr.Error("Please select an OpenEnv task.")
    if seed is None:
        raise gr.Error("Please enter a seed.")

    # Labels are built as f"{task_id} - {cfg.name}". Split on that exact
    # separator, not the first space, so the id is always recovered whole.
    task_id = task_label.partition(" - ")[0]
    episode_seed = int(seed)

    baseline = run_episode(task_id, episode_seed, lru_baseline)
    agent = run_episode(task_id, episode_seed, smart_agent)
    uplift = agent.final_hit_rate - baseline.final_hit_rate
    reward_uplift = agent.total_reward - baseline.total_reward
    summary = (
        f"### Results for `{task_id}`\n"
        f"- Baseline LRU reward: **{baseline.total_reward:.2f}**, hit rate: **{baseline.final_hit_rate:.1%}**\n"
        f"- Fine-tuned agent reward: **{agent.total_reward:.2f}**, hit rate: **{agent.final_hit_rate:.1%}**\n"
        f"- Reward uplift: **{reward_uplift:+.2f}** | Hit-rate uplift: **{uplift:+.1%}**\n\n"
        "The agent keeps viral/previewed objects, evicts low-frequency cold content, "
        "and avoids unnecessary churn under cache pressure."
    )
    return summary, make_plot(baseline, agent)
# One dropdown label per configured task, formatted "<task_id> - <task name>",
# in TASK_CONFIGS iteration order. run_demo recovers the task id from the
# part of the label before the separator.
task_choices = [f"{key} - {config.name}" for key, config in TASK_CONFIGS.items()]
# Gradio page. Components are created inside the Blocks context in display
# order, so statement order (and nesting) here defines the page layout.
with gr.Blocks(title="CDN Cache Optimizer") as demo:
    # Static header text.
    gr.Markdown(
        """
# CDN Cache Optimizer
OpenEnv-compliant reinforcement-learning environment for edge CDN cache
admission and eviction. The live demo compares an LRU baseline with a
fine-tuned agent policy on realistic steady and viral traffic.
"""
    )
    # Inputs: task selector (defaults to the last entry of task_choices) and
    # the episode seed.
    with gr.Row():
        task = gr.Dropdown(task_choices, value=task_choices[-1], label="OpenEnv task")
        seed = gr.Number(value=42, precision=0, label="Seed")
    run_btn = gr.Button("Run Benchmark", variant="primary")
    # Outputs filled by run_demo: markdown summary and comparison figure.
    output = gr.Markdown()
    plot = gr.Plot()
    # Run the benchmark on button click and also once when the page loads,
    # presumably so visitors see results without clicking.
    run_btn.click(run_demo, inputs=[task, seed], outputs=[output, plot])
    demo.load(run_demo, inputs=[task, seed], outputs=[output, plot])
if __name__ == "__main__":
    # Script entry point. Bind address and port come from the
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables,
    # falling back to 127.0.0.1:7860.
    import os

    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )