from __future__ import annotations
import json
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from statistics import mean
from baselines.heuristics import greedy_queue_policy
from workflow_twin.environment import WorkflowTwinEnv
from workflow_twin.memory import MemoryBoundedEnv
@dataclass
class EvalResult:
    """Aggregated metrics for one evaluation scenario run in one mode.

    Attributes:
        test_name: Identifier of the scenario that produced the metrics.
        mode: Mode string that was handed to ``MemoryBoundedEnv``.
        episodes: Number of episodes the averages were taken over.
        avg_reward: Mean of the per-episode summed rewards.
        success_rate: Mean of ``completed / (completed + waiting + current)``
            read from the base environment's queue state after each episode.
        avg_sla_violations: Mean of ``sla_violations`` from the last step's
            info dict of each episode.
        avg_memory_used: Mean of ``memory_used`` from the last step's info.
        avg_memory_budget: Mean of ``memory_budget`` from the last step's info.
        memory_compliance_rate: Fraction of episodes whose final
            ``memory_used`` did not exceed ``memory_budget``.
        steps_per_sec: Total environment steps divided by elapsed wall time.
    """

    # Scenario identification.
    test_name: str
    mode: str
    episodes: int

    # Task-quality metrics.
    avg_reward: float
    success_rate: float
    avg_sla_violations: float

    # Memory metrics.
    avg_memory_used: float
    avg_memory_budget: float
    memory_compliance_rate: float

    # Throughput.
    steps_per_sec: float
def run_control_test(mode: str, episodes: int = 100) -> EvalResult:
    """Evaluate the greedy queue policy on the control scenario.

    Every episode builds a level-1 ``WorkflowTwinEnv`` seeded with
    ``100 + i``, wraps it in a ``MemoryBoundedEnv`` (budget 100,000,
    ``bits=3``) and steps ``greedy_queue_policy`` until the environment
    reports ``done`` or 10 steps have been taken.

    Args:
        mode: Forwarded unchanged as ``mode`` to ``MemoryBoundedEnv``.
        episodes: Number of episodes to run; must be at least 1.

    Returns:
        An ``EvalResult`` named ``"control_no_memory_pressure"`` holding the
        rounded per-episode averages and the measured steps per second.

    Raises:
        ValueError: If ``episodes`` is less than 1.
    """
    if episodes < 1:
        # Fail fast with a clear message; otherwise the averages below would
        # raise an opaque StatisticsError from mean() on empty data.
        raise ValueError(f"episodes must be at least 1, got {episodes}")

    budget = 100_000
    max_steps = 10

    rewards = []
    success_ratios = []
    sla_violations = []
    memory_used = []
    memory_budget = []
    compliance = []
    total_steps = 0
    start = time.perf_counter()
    for i in range(episodes):
        base_env = WorkflowTwinEnv(level=1, seed=100 + i)
        env = MemoryBoundedEnv(base_env, memory_budget=budget, bits=3, mode=mode)
        obs = env.reset()
        done = False
        ep_reward = 0.0
        steps = 0
        last_info = {}
        while not done and steps < max_steps:
            action = greedy_queue_policy(obs)
            obs, reward, done, info = env.step(action.model_dump())
            ep_reward += reward
            steps += 1
            last_info = info
        total_steps += steps
        rewards.append(ep_reward)

        # Success is measured on the unwrapped environment's queue state.
        queue = base_env.state()["queue"]
        completed = queue["completed"]
        in_progress = 1 if queue["current_ticket"] else 0
        # max(..., 1) avoids dividing by zero when the queue is empty.
        total = max(completed + queue["waiting"] + in_progress, 1)
        success_ratios.append(completed / total)

        sla_violations.append(last_info.get("sla_violations", 0))
        memory_info = last_info.get("memory", {})
        memory_used.append(memory_info.get("memory_used", 0))
        # Fall back to the configured budget (as the budget sweep does) so a
        # missing report does not drag the averaged budget towards 0.
        memory_budget.append(memory_info.get("memory_budget", budget))
        compliance.append(memory_used[-1] <= memory_budget[-1])
    elapsed = time.perf_counter() - start

    return EvalResult(
        test_name="control_no_memory_pressure",
        mode=mode,
        episodes=episodes,
        avg_reward=round(mean(rewards), 4),
        success_rate=round(mean(success_ratios), 4),
        avg_sla_violations=round(mean(sla_violations), 4),
        avg_memory_used=round(mean(memory_used), 2),
        avg_memory_budget=round(mean(memory_budget), 2),
        memory_compliance_rate=round(mean(1.0 if ok else 0.0 for ok in compliance), 4),
        # Clamp elapsed so an extremely fast run cannot divide by zero.
        steps_per_sec=round(total_steps / max(elapsed, 1e-6), 2),
    )
def run_memory_test(mode: str, episodes: int = 100) -> EvalResult:
    """Evaluate the greedy queue policy on the memory-constrained scenario.

    Every episode builds a level-5 ``WorkflowTwinEnv`` seeded with
    ``500 + i``, wraps it in a ``MemoryBoundedEnv`` (budget 3500,
    ``bits=3``) and steps ``greedy_queue_policy`` until the environment
    reports ``done`` or 80 steps have been taken.

    Args:
        mode: Forwarded unchanged as ``mode`` to ``MemoryBoundedEnv``.
        episodes: Number of episodes to run; must be at least 1.

    Returns:
        An ``EvalResult`` named ``"critical_memory_constrained_long_horizon"``
        holding the rounded per-episode averages and the measured steps per
        second.

    Raises:
        ValueError: If ``episodes`` is less than 1.
    """
    if episodes < 1:
        # Fail fast with a clear message; otherwise the averages below would
        # raise an opaque StatisticsError from mean() on empty data.
        raise ValueError(f"episodes must be at least 1, got {episodes}")

    budget = 3500
    max_steps = 80

    rewards = []
    success_ratios = []
    sla_violations = []
    memory_used = []
    memory_budget = []
    compliance = []
    total_steps = 0
    start = time.perf_counter()
    for i in range(episodes):
        base_env = WorkflowTwinEnv(level=5, seed=500 + i)
        env = MemoryBoundedEnv(base_env, memory_budget=budget, bits=3, mode=mode)
        obs = env.reset()
        done = False
        ep_reward = 0.0
        steps = 0
        last_info = {}
        while not done and steps < max_steps:
            action = greedy_queue_policy(obs)
            obs, reward, done, info = env.step(action.model_dump())
            ep_reward += reward
            steps += 1
            last_info = info
        total_steps += steps
        rewards.append(ep_reward)

        # Success is measured on the unwrapped environment's queue state.
        queue = base_env.state()["queue"]
        completed = queue["completed"]
        in_progress = 1 if queue["current_ticket"] else 0
        # max(..., 1) avoids dividing by zero when the queue is empty.
        total = max(completed + queue["waiting"] + in_progress, 1)
        success_ratios.append(completed / total)

        sla_violations.append(last_info.get("sla_violations", 0))
        memory_info = last_info.get("memory", {})
        memory_used.append(memory_info.get("memory_used", 0))
        # Fall back to the configured budget (as the budget sweep does) so a
        # missing report does not drag the averaged budget towards 0.
        memory_budget.append(memory_info.get("memory_budget", budget))
        compliance.append(memory_used[-1] <= memory_budget[-1])
    elapsed = time.perf_counter() - start

    return EvalResult(
        test_name="critical_memory_constrained_long_horizon",
        mode=mode,
        episodes=episodes,
        avg_reward=round(mean(rewards), 4),
        success_rate=round(mean(success_ratios), 4),
        avg_sla_violations=round(mean(sla_violations), 4),
        avg_memory_used=round(mean(memory_used), 2),
        avg_memory_budget=round(mean(memory_budget), 2),
        memory_compliance_rate=round(mean(1.0 if ok else 0.0 for ok in compliance), 4),
        # Clamp elapsed so an extremely fast run cannot divide by zero.
        steps_per_sec=round(total_steps / max(elapsed, 1e-6), 2),
    )
def run_memory_budget_sweep(
    budgets: list[int] | None = None,
    episodes: int = 80,
) -> list[dict]:
    """Run the level-5 scenario across memory budgets for both modes.

    For every budget and for each of the modes ``"baseline"`` and
    ``"quant"``, ``episodes`` episodes are run with ``greedy_queue_policy``
    (at most 80 steps each, seed ``1200 + budget + i``) and summarised into
    one result row.

    Args:
        budgets: Memory budgets to evaluate. Defaults to
            ``[2000, 3000, 4000, 6000]`` when ``None``.
        episodes: Episodes per (budget, mode) pair; must be at least 1.

    Returns:
        One dict per (budget, mode) pair, ordered by budget then mode, with
        the keys ``test_name``, ``budget``, ``mode``, ``episodes``,
        ``avg_reward``, ``success_rate``, ``memory_compliance_rate`` and
        ``steps_per_sec``.

    Raises:
        ValueError: If ``episodes`` is less than 1.
    """
    if episodes < 1:
        # Fail fast with a clear message; otherwise the averages below would
        # raise an opaque StatisticsError from mean() on empty data.
        raise ValueError(f"episodes must be at least 1, got {episodes}")
    if budgets is None:
        budgets = [2000, 3000, 4000, 6000]

    max_steps = 80
    sweep_results: list[dict] = []
    for budget in budgets:
        for mode in ("baseline", "quant"):
            rewards = []
            success_ratios = []
            compliance = []
            total_steps = 0
            start = time.perf_counter()
            for i in range(episodes):
                # Both modes use the same seeds for a given budget.
                base_env = WorkflowTwinEnv(level=5, seed=1200 + budget + i)
                env = MemoryBoundedEnv(base_env, memory_budget=budget, bits=3, mode=mode)
                obs = env.reset()
                done = False
                ep_reward = 0.0
                last_info = {}
                steps = 0
                while not done and steps < max_steps:
                    action = greedy_queue_policy(obs)
                    obs, reward, done, info = env.step(action.model_dump())
                    ep_reward += reward
                    last_info = info
                    steps += 1
                total_steps += steps
                rewards.append(ep_reward)

                # Success is measured on the unwrapped environment's queue state.
                queue = base_env.state()["queue"]
                completed = queue["completed"]
                in_progress = 1 if queue["current_ticket"] else 0
                # max(..., 1) avoids dividing by zero when the queue is empty.
                total = max(completed + queue["waiting"] + in_progress, 1)
                success_ratios.append(completed / total)

                memory_info = last_info.get("memory", {})
                mem_used = memory_info.get("memory_used", 0)
                mem_budget = memory_info.get("memory_budget", budget)
                compliance.append(mem_used <= mem_budget)
            elapsed = time.perf_counter() - start

            sweep_results.append(
                {
                    "test_name": "memory_budget_sweep",
                    "budget": budget,
                    "mode": mode,
                    "episodes": episodes,
                    "avg_reward": round(mean(rewards), 4),
                    "success_rate": round(mean(success_ratios), 4),
                    "memory_compliance_rate": round(mean(1.0 if ok else 0.0 for ok in compliance), 4),
                    # Clamp elapsed so an extremely fast run cannot divide by zero.
                    "steps_per_sec": round(total_steps / max(elapsed, 1e-6), 2),
                }
            )
    return sweep_results
def export_compliance_figure(sweep_rows: list[dict], out_path: Path) -> None:
    """Write the budget-versus-compliance figure for a budget sweep.

    Args:
        sweep_rows: Rows carrying at least the keys ``budget``, ``mode`` and
            ``memory_compliance_rate``. Rows with mode ``"baseline"`` and
            ``"quant"`` form the two series; a budget missing from a series
            is plotted at rate 0.0.
        out_path: Destination file. Missing parent directories are created.

    Raises:
        ValueError: If ``sweep_rows`` is empty (no budget range to scale).
    """
    budgets = sorted({row["budget"] for row in sweep_rows})
    baseline = {row["budget"]: row["memory_compliance_rate"] for row in sweep_rows if row["mode"] == "baseline"}
    quant = {row["budget"]: row["memory_compliance_rate"] for row in sweep_rows if row["mode"] == "quant"}

    # Canvas size and margins; the plot area is what remains inside them.
    width, height = 760, 440
    left, right, top, bottom = 80, 40, 40, 70
    plot_w = width - left - right
    plot_h = height - top - bottom
    min_b, max_b = min(budgets), max(budgets)

    def x_pos(budget: int) -> float:
        # A single distinct budget is centred instead of dividing by zero.
        if max_b == min_b:
            return left + plot_w / 2
        return left + (budget - min_b) / (max_b - min_b) * plot_w

    def y_pos(rate: float) -> float:
        # Clamp to [0, 1]; rate 1.0 maps to the top edge of the plot area.
        rate = max(0.0, min(1.0, rate))
        return top + (1.0 - rate) * plot_h

    # NOTE(review): every markup literal below is empty in the current file,
    # so the document written is a single newline and the computed points,
    # grid and labels are unused. Confirm whether the SVG templates were
    # lost; the literals are deliberately kept exactly as found.
    def polyline(points: list[tuple[float, float]], color: str) -> str:
        pts = " ".join(f"{x:.2f},{y:.2f}" for x, y in points)
        circles = "\n".join(
            f'' for x, y in points
        )
        return (
            f'\n'
            f"{circles}"
        )

    baseline_points = [(x_pos(b), y_pos(baseline.get(b, 0.0))) for b in budgets]
    quant_points = [(x_pos(b), y_pos(quant.get(b, 0.0))) for b in budgets]
    y_ticks = [0.0, 0.25, 0.5, 0.75, 1.0]
    y_grid = "\n".join(
        f'\n'
        f'{v:.2f}'
        for v in y_ticks
    )
    x_labels = "\n".join(
        f'{b}'
        for b in budgets
    )
    svg = f'''
'''
    # write_text does not create directories, so make sure the target
    # directory exists before writing.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(svg, encoding="utf-8")
def main() -> None:
    """Run all evaluations, export the compliance figure, and print JSON.

    The control test and then the memory test are each run for the
    ``"baseline"`` and ``"quant"`` modes, followed by the budget sweep. The
    sweep rows feed the figure written next to this file under
    ``figures/``. All results are printed to stdout as one JSON array.
    """
    modes = ["baseline", "quant"]

    results = [run_control_test(mode=current) for current in modes]
    results.extend(run_memory_test(mode=current) for current in modes)

    sweep_rows = run_memory_budget_sweep()
    results.extend(sweep_rows)

    figures_dir = Path(__file__).resolve().parent / "figures"
    export_compliance_figure(sweep_rows, figures_dir / "memory_budget_vs_compliance.svg")

    # EvalResult instances become dicts; sweep rows already are dicts.
    payload = [
        asdict(entry) if isinstance(entry, EvalResult) else entry
        for entry in results
    ]
    print(json.dumps(payload, indent=2))
# Run the full evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()