# workflow-twin / experiments /ab_quantized_memory_eval.py
# NDGCodes's picture
# fix repo structure for HF
# 1a692ce
from __future__ import annotations
import json
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from statistics import mean
from baselines.heuristics import greedy_queue_policy
from workflow_twin.environment import WorkflowTwinEnv
from workflow_twin.memory import MemoryBoundedEnv
@dataclass
class EvalResult:
    """Summary metrics for one evaluation run of a single mode.

    Instances are built by the ``run_*_test`` functions in this file and
    converted with ``dataclasses.asdict`` before being printed as JSON.
    """

    # --- run identification ---
    test_name: str  # label of the scenario that produced the metrics
    mode: str  # mode string handed to MemoryBoundedEnv (e.g. "baseline", "quant")
    episodes: int  # number of episodes requested for the run

    # --- task metrics (averaged over episodes) ---
    avg_reward: float  # mean of the per-episode summed step rewards
    success_rate: float  # mean of completed / (completed + waiting + current)
    avg_sla_violations: float  # mean of "sla_violations" from each episode's last info

    # --- memory metrics (taken from each episode's last info["memory"]) ---
    avg_memory_used: float  # mean reported "memory_used"
    avg_memory_budget: float  # mean reported "memory_budget"
    memory_compliance_rate: float  # share of episodes with memory_used <= memory_budget

    # --- throughput ---
    steps_per_sec: float  # total env steps divided by wall-clock seconds for the run
def _run_memory_bounded_eval(
    *,
    test_name: str,
    mode: str,
    episodes: int,
    level: int,
    seed_base: int,
    memory_budget: int,
    max_steps: int,
    bits: int = 3,
) -> EvalResult:
    """Roll out the greedy queue policy and aggregate the metrics into an EvalResult.

    Shared implementation behind ``run_control_test`` and ``run_memory_test``;
    the two scenarios differ only in the constants they pass in.

    Args:
        test_name: Label stored in the result.
        mode: Mode string forwarded to ``MemoryBoundedEnv``.
        episodes: Number of episodes; episode ``i`` uses seed ``seed_base + i``.
        level: Level passed to ``WorkflowTwinEnv``.
        seed_base: Seed of the first episode.
        memory_budget: Budget passed to ``MemoryBoundedEnv``.
        max_steps: Hard cap on steps per episode, applied even if the
            environment has not reported ``done``.
        bits: ``bits`` argument passed to ``MemoryBoundedEnv``.

    Returns:
        The aggregated metrics, rounded as in the JSON report.

    Raises:
        statistics.StatisticsError: If ``episodes`` is less than 1, because
            there is nothing to average.
    """
    rewards: list[float] = []
    success_ratios: list[float] = []
    sla_violations: list[float] = []
    memory_used: list[float] = []
    budgets_seen: list[float] = []
    compliance: list[bool] = []
    total_steps = 0

    # The timer deliberately spans environment construction and reset, so
    # steps_per_sec reflects the whole run rather than only the step calls.
    start = time.perf_counter()
    for i in range(episodes):
        base_env = WorkflowTwinEnv(level=level, seed=seed_base + i)
        env = MemoryBoundedEnv(base_env, memory_budget=memory_budget, bits=bits, mode=mode)
        obs = env.reset()
        done = False
        ep_reward = 0.0
        steps = 0
        last_info: dict = {}
        while not done and steps < max_steps:
            action = greedy_queue_policy(obs)
            obs, reward, done, info = env.step(action.model_dump())
            ep_reward += reward
            steps += 1
            total_steps += 1
            last_info = info
        rewards.append(ep_reward)

        # Queue counters are read from the wrapped base environment.
        queue = base_env.state()["queue"]
        completed = queue["completed"]
        waiting = queue["waiting"]
        current = 1 if queue["current_ticket"] else 0
        # max(..., 1) avoids dividing by zero when the queue reports no tickets.
        total = max(completed + waiting + current, 1)
        success_ratios.append(completed / total)

        # Only the final step's info is sampled; missing keys count as 0.
        memory_info = last_info.get("memory", {})
        sla_violations.append(last_info.get("sla_violations", 0))
        memory_used.append(memory_info.get("memory_used", 0))
        budgets_seen.append(memory_info.get("memory_budget", 0))
        compliance.append(memory_used[-1] <= budgets_seen[-1])
    elapsed = time.perf_counter() - start

    avg_reward = round(mean(rewards), 4)
    success_rate = round(mean(success_ratios), 4)
    avg_sla_violations = round(mean(sla_violations), 4)
    avg_memory_used = round(mean(memory_used), 2)
    avg_memory_budget = round(mean(budgets_seen), 2)
    memory_compliance_rate = round(mean(1.0 if ok else 0.0 for ok in compliance), 4)
    # Clamp the elapsed time so a near-instant run cannot divide by zero.
    steps_per_sec = round(total_steps / max(elapsed, 1e-6), 2)

    return EvalResult(
        test_name=test_name,
        mode=mode,
        episodes=episodes,
        avg_reward=avg_reward,
        success_rate=success_rate,
        avg_sla_violations=avg_sla_violations,
        avg_memory_used=avg_memory_used,
        avg_memory_budget=avg_memory_budget,
        memory_compliance_rate=memory_compliance_rate,
        steps_per_sec=steps_per_sec,
    )


def run_control_test(mode: str, episodes: int = 100) -> EvalResult:
    """Evaluate ``mode`` on the control scenario.

    Uses level 1, seeds starting at 100, a memory budget of 100_000 and at
    most 10 steps per episode.

    Args:
        mode: Mode string forwarded to ``MemoryBoundedEnv``.
        episodes: Number of episodes to run.

    Returns:
        Metrics labelled ``"control_no_memory_pressure"``.

    Raises:
        statistics.StatisticsError: If ``episodes`` is less than 1.
    """
    return _run_memory_bounded_eval(
        test_name="control_no_memory_pressure",
        mode=mode,
        episodes=episodes,
        level=1,
        seed_base=100,
        memory_budget=100_000,
        max_steps=10,
    )


def run_memory_test(mode: str, episodes: int = 100) -> EvalResult:
    """Evaluate ``mode`` on the memory-constrained long-horizon scenario.

    Uses level 5, seeds starting at 500, a memory budget of 3500 and at most
    80 steps per episode.

    Args:
        mode: Mode string forwarded to ``MemoryBoundedEnv``.
        episodes: Number of episodes to run.

    Returns:
        Metrics labelled ``"critical_memory_constrained_long_horizon"``.

    Raises:
        statistics.StatisticsError: If ``episodes`` is less than 1.
    """
    return _run_memory_bounded_eval(
        test_name="critical_memory_constrained_long_horizon",
        mode=mode,
        episodes=episodes,
        level=5,
        seed_base=500,
        memory_budget=3500,
        max_steps=80,
    )
def run_memory_budget_sweep(
    budgets: list[int] | None = None,
    episodes: int = 80,
) -> list[dict]:
    """Evaluate the "baseline" and "quant" modes across several memory budgets.

    For every budget, both modes are run on level-5 environments seeded with
    ``1200 + budget + episode_index`` (so the two modes see identical seeds),
    with at most 80 steps per episode.

    Args:
        budgets: Memory budgets to sweep; ``None`` selects
            ``[2000, 3000, 4000, 6000]``.
        episodes: Episodes per (budget, mode) pair.

    Returns:
        One dict per (budget, mode) pair, in sweep order, holding the averaged
        reward, success rate, compliance rate and throughput.
    """
    budget_grid = [2000, 3000, 4000, 6000] if budgets is None else budgets

    rows: list[dict] = []
    for budget in budget_grid:
        for mode in ("baseline", "quant"):
            episode_rewards = []
            completion_ratios = []
            within_budget = []
            t0 = time.perf_counter()
            step_count = 0

            for offset in range(episodes):
                inner_env = WorkflowTwinEnv(level=5, seed=1200 + budget + offset)
                bounded_env = MemoryBoundedEnv(inner_env, memory_budget=budget, bits=3, mode=mode)
                observation = bounded_env.reset()
                finished = False
                reward_sum = 0.0
                final_info = {}

                # Step until the env reports done or the 80-step cap is hit.
                for _ in range(80):
                    if finished:
                        break
                    chosen = greedy_queue_policy(observation)
                    observation, step_reward, finished, step_info = bounded_env.step(
                        chosen.model_dump()
                    )
                    reward_sum += step_reward
                    final_info = step_info
                    step_count += 1
                episode_rewards.append(reward_sum)

                # Completion ratio comes from the wrapped base environment.
                queue = inner_env.state()["queue"]
                n_completed = queue["completed"]
                n_waiting = queue["waiting"]
                n_active = 1 if queue["current_ticket"] else 0
                denominator = max(n_completed + n_waiting + n_active, 1)
                completion_ratios.append(n_completed / denominator)

                # Compliance is judged on the final step's memory report; a
                # missing budget entry falls back to the configured budget.
                memory_report = final_info.get("memory", {})
                used = memory_report.get("memory_used", 0)
                allowed = memory_report.get("memory_budget", budget)
                within_budget.append(used <= allowed)

            seconds = time.perf_counter() - t0
            row = {
                "test_name": "memory_budget_sweep",
                "budget": budget,
                "mode": mode,
                "episodes": episodes,
                "avg_reward": round(mean(episode_rewards), 4),
                "success_rate": round(mean(completion_ratios), 4),
                "memory_compliance_rate": round(
                    mean(1.0 if flag else 0.0 for flag in within_budget), 4
                ),
                "steps_per_sec": round(step_count / max(seconds, 1e-6), 2),
            }
            rows.append(row)
    return rows
def export_compliance_figure(sweep_rows: list[dict], out_path: Path) -> None:
    """Write an SVG line chart of memory-compliance rate per budget, one line per mode.

    Args:
        sweep_rows: Row dicts carrying ``"budget"``, ``"mode"`` and
            ``"memory_compliance_rate"`` keys. Rows whose mode is
            ``"baseline"`` or ``"quant"`` are plotted; a budget with no row
            for a mode is drawn at a rate of 0.0. If a (budget, mode) pair
            appears more than once, the last row wins.
        out_path: Destination file. Missing parent directories are created.

    Raises:
        ValueError: If ``sweep_rows`` contains no budgets to plot.
    """
    budgets = sorted({row["budget"] for row in sweep_rows})
    if not budgets:
        # Fail with a clear message before touching the filesystem instead of
        # letting min()/max() raise an opaque "empty sequence" error.
        raise ValueError("sweep_rows is empty; there are no budgets to plot")

    baseline = {
        row["budget"]: row["memory_compliance_rate"]
        for row in sweep_rows
        if row["mode"] == "baseline"
    }
    quant = {
        row["budget"]: row["memory_compliance_rate"]
        for row in sweep_rows
        if row["mode"] == "quant"
    }

    # Canvas size and margins, in SVG user units.
    width, height = 760, 440
    left, right, top, bottom = 80, 40, 40, 70
    plot_w = width - left - right
    plot_h = height - top - bottom
    min_b, max_b = min(budgets), max(budgets)

    def x_pos(budget: int) -> float:
        """Map a budget onto the horizontal axis."""
        if max_b == min_b:
            # A single budget has no range to scale over; centre it.
            return left + plot_w / 2
        return left + (budget - min_b) / (max_b - min_b) * plot_w

    def y_pos(rate: float) -> float:
        """Map a rate, clamped to [0, 1], onto the vertical axis (1.0 at the top)."""
        rate = max(0.0, min(1.0, rate))
        return top + (1.0 - rate) * plot_h

    def polyline(points: list[tuple[float, float]], color: str) -> str:
        """Return SVG markup for a line through ``points`` plus one marker per point."""
        pts = " ".join(f"{x:.2f},{y:.2f}" for x, y in points)
        circles = "\n".join(
            f'<circle cx="{x:.2f}" cy="{y:.2f}" r="4" fill="{color}" />' for x, y in points
        )
        return (
            f'<polyline fill="none" stroke="{color}" stroke-width="3" points="{pts}" />\n'
            f"{circles}"
        )

    baseline_points = [(x_pos(b), y_pos(baseline.get(b, 0.0))) for b in budgets]
    quant_points = [(x_pos(b), y_pos(quant.get(b, 0.0))) for b in budgets]
    baseline_markup = polyline(baseline_points, "#ef4444")
    quant_markup = polyline(quant_points, "#2563eb")

    y_ticks = [0.0, 0.25, 0.5, 0.75, 1.0]
    y_grid = "\n".join(
        f'<line x1="{left}" y1="{y_pos(v):.2f}" x2="{left+plot_w}" y2="{y_pos(v):.2f}" '
        f'stroke="#e5e7eb" stroke-width="1" />\n'
        f'<text x="{left-10}" y="{y_pos(v)+5:.2f}" text-anchor="end" font-size="12" fill="#374151">{v:.2f}</text>'
        for v in y_ticks
    )
    x_labels = "\n".join(
        f'<text x="{x_pos(b):.2f}" y="{top+plot_h+24}" text-anchor="middle" font-size="12" fill="#374151">{b}</text>'
        for b in budgets
    )

    svg = f'''<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">
<rect x="0" y="0" width="{width}" height="{height}" fill="white"/>
<text x="{width/2}" y="24" text-anchor="middle" font-size="18" font-family="Arial" fill="#111827">Memory Budget vs Compliance Rate</text>
{y_grid}
<line x1="{left}" y1="{top}" x2="{left}" y2="{top+plot_h}" stroke="#6b7280" stroke-width="1.5" />
<line x1="{left}" y1="{top+plot_h}" x2="{left+plot_w}" y2="{top+plot_h}" stroke="#6b7280" stroke-width="1.5" />
{x_labels}
<text x="{left + plot_w/2}" y="{height-18}" text-anchor="middle" font-size="13" fill="#111827">Memory Budget</text>
<text x="18" y="{top + plot_h/2}" text-anchor="middle" font-size="13" fill="#111827" transform="rotate(-90,18,{top + plot_h/2})">Compliance Rate</text>
{baseline_markup}
{quant_markup}
<rect x="{width-190}" y="50" width="160" height="54" fill="#f9fafb" stroke="#d1d5db" />
<line x1="{width-176}" y1="68" x2="{width-146}" y2="68" stroke="#ef4444" stroke-width="3" />
<text x="{width-140}" y="72" font-size="12" fill="#111827">baseline</text>
<line x1="{width-176}" y1="88" x2="{width-146}" y2="88" stroke="#2563eb" stroke-width="3" />
<text x="{width-140}" y="92" font-size="12" fill="#111827">quant</text>
</svg>
'''

    # Accept str paths too, and create the output directory on first use so a
    # fresh checkout without a "figures" folder does not fail on write.
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(svg, encoding="utf-8")
def main() -> None:
    """Run every evaluation, export the compliance figure and print a JSON report.

    Results are emitted in a fixed order: the control test for both modes,
    the memory test for both modes, then the budget-sweep rows. The figure is
    written to ``figures/memory_budget_vs_compliance.svg`` next to this file.
    """
    modes = ("baseline", "quant")

    report: list = [run_control_test(mode=m) for m in modes]
    report += [run_memory_test(mode=m) for m in modes]

    sweep_rows = run_memory_budget_sweep()
    report += sweep_rows

    script_dir = Path(__file__).resolve().parent
    figure_path = script_dir / "figures" / "memory_budget_vs_compliance.svg"
    export_compliance_figure(sweep_rows, figure_path)

    # Dataclass results are flattened to dicts; sweep rows are dicts already.
    payload = [
        asdict(entry) if isinstance(entry, EvalResult) else entry
        for entry in report
    ]
    print(json.dumps(payload, indent=2))


if __name__ == "__main__":
    main()