Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import time | |
| from dataclasses import dataclass, asdict | |
| from pathlib import Path | |
| from statistics import mean | |
| from baselines.heuristics import greedy_queue_policy | |
| from workflow_twin.environment import WorkflowTwinEnv | |
| from workflow_twin.memory import MemoryBoundedEnv | |
@dataclass
class EvalResult:
    """Aggregated metrics for one evaluation scenario run in one memory mode.

    The ``@dataclass`` decorator is required: the evaluation functions build
    instances with keyword arguments and ``main`` serializes them with
    ``dataclasses.asdict``; a plain class supports neither.
    """

    # Scenario identifier, e.g. "control_no_memory_pressure".
    test_name: str
    # Mode string that was forwarded to MemoryBoundedEnv ("baseline" / "quant" in main).
    mode: str
    # Number of episodes that were averaged.
    episodes: int
    # Mean of the per-episode summed step rewards.
    avg_reward: float
    # Mean of completed / (completed + waiting + current ticket) per episode.
    success_rate: float
    # Mean of the "sla_violations" value reported by the last step of each episode.
    avg_sla_violations: float
    # Mean of the "memory_used" value reported by the last step of each episode.
    avg_memory_used: float
    # Mean of the "memory_budget" value reported by the last step of each episode.
    avg_memory_budget: float
    # Fraction of episodes whose final memory_used was <= memory_budget.
    memory_compliance_rate: float
    # Environment steps executed per wall-clock second over the whole run.
    steps_per_sec: float
def run_control_test(mode: str, episodes: int = 100) -> EvalResult:
    """Evaluate the greedy queue policy on the control scenario.

    Every episode wraps a level-1 ``WorkflowTwinEnv`` (seed ``100 + i``) in a
    ``MemoryBoundedEnv`` with a memory budget of 100,000 and 3 bits, then
    follows ``greedy_queue_policy`` until the episode ends or 10 steps have
    been taken.

    Args:
        mode: Mode string forwarded to ``MemoryBoundedEnv``.
        episodes: Number of episodes to run and average over.

    Returns:
        An ``EvalResult`` named ``"control_no_memory_pressure"`` holding the
        rounded per-episode averages and the overall steps-per-second rate.
    """
    max_steps = 10

    episode_rewards = []
    completion_ratios = []
    final_sla_violations = []
    used_samples = []
    budget_samples = []
    within_budget = []
    step_count = 0

    # The timer also covers environment construction and reset.
    t0 = time.perf_counter()
    for episode_idx in range(episodes):
        inner_env = WorkflowTwinEnv(level=1, seed=100 + episode_idx)
        env = MemoryBoundedEnv(inner_env, memory_budget=100_000, bits=3, mode=mode)
        obs = env.reset()

        episode_reward = 0.0
        final_info = {}
        for _ in range(max_steps):
            decision = greedy_queue_policy(obs)
            obs, reward, done, final_info = env.step(decision.model_dump())
            episode_reward += reward
            step_count += 1
            if done:
                break
        episode_rewards.append(episode_reward)

        # Queue progress is read from the unwrapped environment.
        snapshot = inner_env.state()
        finished = snapshot["queue"]["completed"]
        pending = snapshot["queue"]["waiting"]
        in_progress = 1 if snapshot["queue"]["current_ticket"] else 0
        # Guard against an empty queue so the ratio never divides by zero.
        ticket_total = max(finished + pending + in_progress, 1)
        completion_ratios.append(finished / ticket_total)

        final_sla_violations.append(final_info.get("sla_violations", 0))
        used_samples.append(final_info.get("memory", {}).get("memory_used", 0))
        budget_samples.append(final_info.get("memory", {}).get("memory_budget", 0))
        within_budget.append(1.0 if used_samples[-1] <= budget_samples[-1] else 0.0)
    elapsed = time.perf_counter() - t0

    return EvalResult(
        test_name="control_no_memory_pressure",
        mode=mode,
        episodes=episodes,
        avg_reward=round(mean(episode_rewards), 4),
        success_rate=round(mean(completion_ratios), 4),
        avg_sla_violations=round(mean(final_sla_violations), 4),
        avg_memory_used=round(mean(used_samples), 2),
        avg_memory_budget=round(mean(budget_samples), 2),
        memory_compliance_rate=round(mean(within_budget), 4),
        # Clamp the elapsed time so a near-instant run cannot divide by zero.
        steps_per_sec=round(step_count / max(elapsed, 1e-6), 2),
    )
def run_memory_test(mode: str, episodes: int = 100) -> EvalResult:
    """Evaluate the greedy queue policy on the memory-constrained scenario.

    Every episode wraps a level-5 ``WorkflowTwinEnv`` (seed ``500 + i``) in a
    ``MemoryBoundedEnv`` with a memory budget of 3500 and 3 bits, then follows
    ``greedy_queue_policy`` until the episode ends or 80 steps have been taken.

    Args:
        mode: Mode string forwarded to ``MemoryBoundedEnv``.
        episodes: Number of episodes to run and average over.

    Returns:
        An ``EvalResult`` named ``"critical_memory_constrained_long_horizon"``
        holding the rounded per-episode averages and the overall
        steps-per-second rate.
    """
    step_limit = 80

    reward_totals = []
    completion_ratios = []
    sla_counts = []
    mem_used_log = []
    mem_budget_log = []
    compliant_flags = []
    steps_taken = 0

    # The timer also covers environment construction and reset.
    started_at = time.perf_counter()
    for episode_no in range(episodes):
        raw_env = WorkflowTwinEnv(level=5, seed=500 + episode_no)
        env = MemoryBoundedEnv(raw_env, memory_budget=3500, bits=3, mode=mode)
        obs = env.reset()

        reward_sum = 0.0
        info_at_end = {}
        for _ in range(step_limit):
            chosen = greedy_queue_policy(obs)
            obs, reward, done, info_at_end = env.step(chosen.model_dump())
            reward_sum += reward
            steps_taken += 1
            if done:
                break
        reward_totals.append(reward_sum)

        # Queue progress is read from the unwrapped environment.
        env_state = raw_env.state()
        n_completed = env_state["queue"]["completed"]
        n_waiting = env_state["queue"]["waiting"]
        n_current = 1 if env_state["queue"]["current_ticket"] else 0
        # Guard against an empty queue so the ratio never divides by zero.
        n_total = max(n_completed + n_waiting + n_current, 1)
        completion_ratios.append(n_completed / n_total)

        sla_counts.append(info_at_end.get("sla_violations", 0))
        mem_used_log.append(info_at_end.get("memory", {}).get("memory_used", 0))
        mem_budget_log.append(info_at_end.get("memory", {}).get("memory_budget", 0))
        compliant_flags.append(1.0 if mem_used_log[-1] <= mem_budget_log[-1] else 0.0)
    duration = time.perf_counter() - started_at

    return EvalResult(
        test_name="critical_memory_constrained_long_horizon",
        mode=mode,
        episodes=episodes,
        avg_reward=round(mean(reward_totals), 4),
        success_rate=round(mean(completion_ratios), 4),
        avg_sla_violations=round(mean(sla_counts), 4),
        avg_memory_used=round(mean(mem_used_log), 2),
        avg_memory_budget=round(mean(mem_budget_log), 2),
        memory_compliance_rate=round(mean(compliant_flags), 4),
        # Clamp the elapsed time so a near-instant run cannot divide by zero.
        steps_per_sec=round(steps_taken / max(duration, 1e-6), 2),
    )
| def run_memory_budget_sweep( | |
| budgets: list[int] | None = None, | |
| episodes: int = 80, | |
| ) -> list[dict]: | |
| if budgets is None: | |
| budgets = [2000, 3000, 4000, 6000] | |
| sweep_results: list[dict] = [] | |
| for budget in budgets: | |
| for mode in ["baseline", "quant"]: | |
| rewards = [] | |
| success_ratios = [] | |
| compliance = [] | |
| start = time.perf_counter() | |
| total_steps = 0 | |
| for i in range(episodes): | |
| base_env = WorkflowTwinEnv(level=5, seed=1200 + budget + i) | |
| env = MemoryBoundedEnv(base_env, memory_budget=budget, bits=3, mode=mode) | |
| obs = env.reset() | |
| done = False | |
| ep_reward = 0.0 | |
| last_info = {} | |
| steps = 0 | |
| while not done and steps < 80: | |
| action = greedy_queue_policy(obs) | |
| obs, reward, done, info = env.step(action.model_dump()) | |
| ep_reward += reward | |
| last_info = info | |
| steps += 1 | |
| total_steps += 1 | |
| rewards.append(ep_reward) | |
| state = base_env.state() | |
| completed = state["queue"]["completed"] | |
| waiting = state["queue"]["waiting"] | |
| current = 1 if state["queue"]["current_ticket"] else 0 | |
| total = max(completed + waiting + current, 1) | |
| success_ratios.append(completed / total) | |
| mem_used = last_info.get("memory", {}).get("memory_used", 0) | |
| mem_budget = last_info.get("memory", {}).get("memory_budget", budget) | |
| compliance.append(mem_used <= mem_budget) | |
| elapsed = time.perf_counter() - start | |
| sweep_results.append( | |
| { | |
| "test_name": "memory_budget_sweep", | |
| "budget": budget, | |
| "mode": mode, | |
| "episodes": episodes, | |
| "avg_reward": round(mean(rewards), 4), | |
| "success_rate": round(mean(success_ratios), 4), | |
| "memory_compliance_rate": round(mean(1.0 if ok else 0.0 for ok in compliance), 4), | |
| "steps_per_sec": round(total_steps / max(elapsed, 1e-6), 2), | |
| } | |
| ) | |
| return sweep_results | |
def export_compliance_figure(sweep_rows: list[dict], out_path: Path) -> None:
    """Write an SVG line chart of memory-compliance rate per budget and mode.

    A red polyline is drawn for rows whose ``"mode"`` is ``"baseline"`` and a
    blue one for ``"quant"``. A budget with no row for a mode is plotted at a
    rate of 0.0, and rates are clamped to ``[0, 1]`` before plotting.

    Args:
        sweep_rows: Rows providing at least the keys ``"budget"``, ``"mode"``
            and ``"memory_compliance_rate"``.
        out_path: Destination of the SVG file. Missing parent directories are
            created so a fresh checkout without the output folder still works.

    Raises:
        ValueError: If ``sweep_rows`` contains no budgets to plot.
    """
    budgets = sorted({row["budget"] for row in sweep_rows})
    if not budgets:
        # Fail with a clear message instead of the opaque error from min([]).
        raise ValueError("sweep_rows must contain at least one row to plot")

    baseline = {
        row["budget"]: row["memory_compliance_rate"]
        for row in sweep_rows
        if row["mode"] == "baseline"
    }
    quant = {
        row["budget"]: row["memory_compliance_rate"]
        for row in sweep_rows
        if row["mode"] == "quant"
    }

    # Canvas size and margins; the plot area is what remains inside the margins.
    width, height = 760, 440
    left, right, top, bottom = 80, 40, 40, 70
    plot_w = width - left - right
    plot_h = height - top - bottom
    min_b, max_b = min(budgets), max(budgets)

    def x_pos(budget: int) -> float:
        """Map a budget onto the horizontal pixel range of the plot area."""
        if max_b == min_b:
            # A single budget has no range to scale over; centre it.
            return left + plot_w / 2
        return left + (budget - min_b) / (max_b - min_b) * plot_w

    def y_pos(rate: float) -> float:
        """Map a rate in [0, 1] onto the vertical pixel range (1.0 at the top)."""
        rate = max(0.0, min(1.0, rate))
        return top + (1.0 - rate) * plot_h

    def polyline(points: list[tuple[float, float]], color: str) -> str:
        """Return SVG markup for a line through ``points`` plus one dot per point."""
        pts = " ".join(f"{x:.2f},{y:.2f}" for x, y in points)
        circles = "\n".join(
            f'<circle cx="{x:.2f}" cy="{y:.2f}" r="4" fill="{color}" />' for x, y in points
        )
        return (
            f'<polyline fill="none" stroke="{color}" stroke-width="3" points="{pts}" />\n'
            f"{circles}"
        )

    baseline_points = [(x_pos(b), y_pos(baseline.get(b, 0.0))) for b in budgets]
    quant_points = [(x_pos(b), y_pos(quant.get(b, 0.0))) for b in budgets]
    baseline_series = polyline(baseline_points, "#ef4444")
    quant_series = polyline(quant_points, "#2563eb")

    # Horizontal grid lines with their tick labels on the left axis.
    y_ticks = [0.0, 0.25, 0.5, 0.75, 1.0]
    y_grid = "\n".join(
        f'<line x1="{left}" y1="{y_pos(v):.2f}" x2="{left+plot_w}" y2="{y_pos(v):.2f}" '
        f'stroke="#e5e7eb" stroke-width="1" />\n'
        f'<text x="{left-10}" y="{y_pos(v)+5:.2f}" text-anchor="end" font-size="12" fill="#374151">{v:.2f}</text>'
        for v in y_ticks
    )
    x_labels = "\n".join(
        f'<text x="{x_pos(b):.2f}" y="{top+plot_h+24}" text-anchor="middle" font-size="12" fill="#374151">{b}</text>'
        for b in budgets
    )

    svg = f'''<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">
<rect x="0" y="0" width="{width}" height="{height}" fill="white"/>
<text x="{width/2}" y="24" text-anchor="middle" font-size="18" font-family="Arial" fill="#111827">Memory Budget vs Compliance Rate</text>
{y_grid}
<line x1="{left}" y1="{top}" x2="{left}" y2="{top+plot_h}" stroke="#6b7280" stroke-width="1.5" />
<line x1="{left}" y1="{top+plot_h}" x2="{left+plot_w}" y2="{top+plot_h}" stroke="#6b7280" stroke-width="1.5" />
{x_labels}
<text x="{left + plot_w/2}" y="{height-18}" text-anchor="middle" font-size="13" fill="#111827">Memory Budget</text>
<text x="18" y="{top + plot_h/2}" text-anchor="middle" font-size="13" fill="#111827" transform="rotate(-90,18,{top + plot_h/2})">Compliance Rate</text>
{baseline_series}
{quant_series}
<rect x="{width-190}" y="50" width="160" height="54" fill="#f9fafb" stroke="#d1d5db" />
<line x1="{width-176}" y1="68" x2="{width-146}" y2="68" stroke="#ef4444" stroke-width="3" />
<text x="{width-140}" y="72" font-size="12" fill="#111827">baseline</text>
<line x1="{width-176}" y1="88" x2="{width-146}" y2="88" stroke="#2563eb" stroke-width="3" />
<text x="{width-140}" y="92" font-size="12" fill="#111827">quant</text>
</svg>
'''

    # Accept plain strings as well, and make sure the target folder exists:
    # write_text does not create missing parent directories.
    target = Path(out_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(svg, encoding="utf-8")
def main() -> None:
    """Run every evaluation, export the compliance chart and print results as JSON.

    The control and memory-constrained tests are each run for the
    ``"baseline"`` and ``"quant"`` modes, followed by the memory budget sweep.
    The sweep rows are rendered to ``figures/memory_budget_vs_compliance.svg``
    next to this file, and all results are printed to stdout as one JSON list.
    """
    modes = ["baseline", "quant"]
    results = [run_control_test(mode=mode) for mode in modes]
    results.extend(run_memory_test(mode=mode) for mode in modes)

    sweep_rows = run_memory_budget_sweep()
    results.extend(sweep_rows)

    figure_path = Path(__file__).resolve().parent / "figures" / "memory_budget_vs_compliance.svg"
    # Create the output folder first: writing the figure fails if it is missing.
    figure_path.parent.mkdir(parents=True, exist_ok=True)
    export_compliance_figure(sweep_rows, figure_path)

    # Test results are EvalResult objects; sweep rows are already plain dicts.
    payload = [asdict(row) if isinstance(row, EvalResult) else row for row in results]
    print(json.dumps(payload, indent=2))


if __name__ == "__main__":
    main()