| """ |
| baselines.py |
| ============ |
| Deterministic baseline routers and RL-based baselines. |
| |
| Literature: |
| - READYS: Grinsztajn et al. (IEEE Cluster 2021) |
| - EdgeSched-DQN: ScienceDirect 2025 |
| - Das et al. (DAC 2014) — thermal optimization |
| - Lee, Shin, Chwa (ACM TECS 2019) — thermal-aware scheduling |
| """ |
|
|
| import random |
| import math |
| from typing import Dict, List, Optional, Tuple |
| from collections import Counter |
| import numpy as np |
|
|
| try: |
| import torch |
| import torch.nn as nn |
| HAS_TORCH = True |
| except ImportError: |
| HAS_TORCH = False |
|
|
| from profiler import TaskComplexityProfile, TaskComplexityProfiler |
| from rl_env import ComplexityAwarePIMEnv |
|
|
|
|
class BaselineRouter:
    """Deterministic routing baselines plus a greedy wrapper around a DQN.

    The three rule-based routes (``route_always_pim``,
    ``route_threshold_rule``, ``route_complexity_only``) return one of the
    target names ``"PIM"``, ``"CPU"`` or ``"GPU"``.  ``route_standard_dqn``
    returns the integer index of the highest Q-value instead.
    """

    def __init__(self):
        # Profiler supplies the per-target suitability scores used by
        # ``route_complexity_only``.
        self.profiler = TaskComplexityProfiler()

    def route_always_pim(self, profile: "TaskComplexityProfile") -> str:
        """Return ``"PIM"`` unconditionally; ``profile`` is ignored."""
        return "PIM"

    def route_threshold_rule(self, profile: "TaskComplexityProfile",
                             T: float, V_th: float) -> str:
        """Route with a fixed priority list of hand-written thresholds.

        Rules are checked in order and the first match wins:

        1. ``V_th > 0.85``                       -> ``"CPU"``
        2. ``T > 85.0``                          -> ``"GPU"``
        3. ``profile.complexity_class == "HEAVY"`` -> ``"GPU"``
        4. memory-bound and ``"LIGHT"``          -> ``"PIM"``
        5. otherwise                             -> ``"CPU"``

        Args:
            profile: Object exposing ``complexity_class`` and
                ``is_memory_bound``.
            T: Temperature reading compared against 85.0.
            V_th: Voltage reading compared against 0.85.
        """
        if V_th > 0.85:
            return "CPU"
        if T > 85.0:
            return "GPU"
        if profile.complexity_class == "HEAVY":
            return "GPU"
        if profile.is_memory_bound and profile.complexity_class == "LIGHT":
            return "PIM"
        return "CPU"

    def route_complexity_only(self, profile: "TaskComplexityProfile") -> str:
        """Return the target with the highest profiler suitability score."""
        scores = self.profiler.compute_suitability_scores(profile)
        return max(scores, key=scores.get)

    def route_standard_dqn(self, state: np.ndarray, policy_net) -> int:
        """Return the greedy action index of ``policy_net`` for ``state``.

        Args:
            state: 1-D state vector; a batch dimension is added here.
            policy_net: Callable mapping a ``(1, state_dim)`` float tensor to
                a ``(1, n_actions)`` tensor of Q-values.

        Raises:
            RuntimeError: If PyTorch is not installed.
        """
        # Import locally so a missing optional dependency surfaces as the same
        # RuntimeError the DQN agent raises, not a NameError on ``torch``.
        try:
            import torch
        except ImportError as exc:
            raise RuntimeError("PyTorch required.") from exc

        # Place the input on the network's device; a CPU tensor fed to a
        # network living elsewhere raises a device-mismatch error.
        try:
            device = next(policy_net.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            device = None  # plain callable or parameter-free module

        with torch.no_grad():
            state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            if device is not None:
                state_t = state_t.to(device)
            q_values = policy_net(state_t).cpu().numpy()[0]
        return int(np.argmax(q_values))
|
|
|
|
| |
| |
| |
|
|
class READYSRouter:
    """READYS-inspired greedy heuristic for the 3-target discrete setting.

    Each target is scored as ``deadline_slack / estimated_exec_time`` and the
    highest score wins, unless a sensor reading triggers one of the fixed
    safety overrides in ``route``.
    """

    TARGETS = ("PIM", "CPU", "GPU")

    def __init__(self):
        # Profiler supplies the per-target latency estimates.
        self.profiler = TaskComplexityProfiler()

    def route(self, profile: "TaskComplexityProfile",
              sensor=None,
              deadline_ms: float = 100.0) -> str:
        """Pick a target name for ``profile``.

        Args:
            profile: Task profile handed to ``profiler.estimate_latency``.
            sensor: Optional object; ``T_current`` and ``voltage_history`` are
                read when present.
            deadline_ms: Deadline the latency estimate is subtracted from.

        Returns:
            ``"GPU"`` if ``sensor.T_current > 85.0``; else ``"CPU"`` if the
            latest ``sensor.voltage_history`` entry exceeds 0.85; otherwise
            the target with the highest slack/latency score.
        """
        scores = {}
        for target in self.TARGETS:
            latency = self.profiler.estimate_latency(profile, target)
            slack = deadline_ms - latency
            # Clamp both terms so a missed deadline or a near-zero estimate
            # still yields a finite, positive score.
            scores[target] = max(slack, 0.01) / max(latency, 0.001)

        # Compare against None explicitly: a sensor object that happens to be
        # falsy (e.g. defines __len__) must not bypass the safety overrides.
        if sensor is not None:
            if getattr(sensor, 'T_current', 25.0) > 85.0:
                return "GPU"
            history = getattr(sensor, 'voltage_history', None)
            if history and history[-1] > 0.85:
                return "CPU"
        return max(scores, key=scores.get)
|
|
|
|
| |
| |
| |
|
|
# Resolve the base class defensively.  When the optional torch import at the
# top of the file failed, ``nn`` is unbound and ``class FlatDQN(nn.Module)``
# would raise NameError while the module is being imported, which defeats the
# HAS_TORCH guard.
try:
    _FlatDQNBase = nn.Module
except NameError:
    _FlatDQNBase = object


class FlatDQN(_FlatDQNBase):
    """Standard (non-dueling) DQN with state+task size inputs.

    A three-layer MLP: two hidden ``Linear -> ReLU -> LayerNorm`` stages of
    width ``hidden_dim`` followed by a linear head with one output per action.
    """

    def __init__(self, state_dim=16, action_dim=3, hidden_dim=256):
        """Build the network.

        Raises:
            RuntimeError: If PyTorch is not available.
        """
        if _FlatDQNBase is object:
            raise RuntimeError("PyTorch required.")
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x):
        """Return the Q-values produced by the MLP for input ``x``."""
        return self.net(x)
|
|
|
|
class EdgeSchedDQNAgent:
    """Flat DQN baseline matching the EdgeSched-DQN architecture.

    Uses the non-dueling ``FlatDQN`` network and no 3-tier hierarchy.

    NOTE(review): the original description also says "no PER", but
    transitions are stored in ``rl_agent.PrioritizedReplayBuffer`` and its
    importance weights / priority updates are used in ``train_step``.
    Confirm whether a uniform replay buffer was intended for this baseline.
    """

    def __init__(self, state_dim=16, action_dim=3, hidden_dim=256,
                 lr=5e-4, gamma=0.99, tau=0.005, buffer_size=50000,
                 batch_size=128, device="cpu"):
        """Create policy/target networks, optimizer and replay memory.

        Raises:
            RuntimeError: If PyTorch is not available.
        """
        if not HAS_TORCH:
            raise RuntimeError("PyTorch required.")
        self.device = torch.device(device)
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.action_dim = action_dim
        self.steps_done = 0

        self.policy_net = FlatDQN(state_dim, action_dim, hidden_dim).to(self.device)
        self.target_net = FlatDQN(state_dim, action_dim, hidden_dim).to(self.device)
        # The target network starts as an exact copy and is only ever updated
        # by the soft update in ``train_step``.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=lr)

        # Imported here, not at module level, as in the original code.
        from rl_agent import PrioritizedReplayBuffer
        self.memory = PrioritizedReplayBuffer(buffer_size, device=device)

    def select_action(self, state: np.ndarray, epsilon: float = 0.0) -> int:
        """Epsilon-greedy action selection.

        With probability ``epsilon`` a uniformly random action index is
        returned; otherwise the argmax of the policy network's Q-values.
        """
        if random.random() < epsilon:
            return random.randrange(self.action_dim)
        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q = self.policy_net(state_t)
        return int(q.argmax(dim=1).item())

    def store_transition(self, *args):
        """Forward a transition tuple unchanged to the replay memory."""
        self.memory.push(*args)

    def train_step(self):
        """Run one optimization step on a sampled batch.

        Returns:
            The scalar loss as a float, or ``None`` when the memory holds
            fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return None

        batch = self.memory.sample(self.batch_size)
        states, actions, rewards, next_states, dones, indices, weights = batch

        # squeeze(1) drops only the gathered action axis; a bare squeeze()
        # would also drop the batch axis when the batch size is 1.
        current_q = self.policy_net(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            next_q = self.target_net(next_states).max(dim=1)[0]
            target_q = rewards + (1 - dones) * self.gamma * next_q

        td_errors = (current_q - target_q).detach().cpu().numpy()
        self.memory.update_priorities(indices, td_errors)

        # Per-sample Huber loss scaled by the buffer's importance weights.
        per_sample = nn.functional.smooth_l1_loss(
            current_q, target_q, reduction='none')
        loss = (weights * per_sample).mean()

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10.0)
        self.optimizer.step()

        # Soft (Polyak) update of the target network toward the policy network.
        for tp, pp in zip(self.target_net.parameters(),
                          self.policy_net.parameters()):
            tp.data.copy_(self.tau * pp.data + (1 - self.tau) * tp.data)
        return float(loss.item())
|
|
|
|
| |
| |
| |
|
|
class BaselineEvaluator:
    """Run every baseline policy in the environment and tabulate the results."""

    ACTION_NAMES = {0: "PIM", 1: "CPU", 2: "GPU"}
    # Inverse of ACTION_NAMES, shared by all name-returning routers.
    ACTION_IDS = {"PIM": 0, "CPU": 1, "GPU": 2}

    def __init__(self, num_eval_episodes: int = 50, max_steps: int = 200):
        self.num_eval_episodes = num_eval_episodes
        self.max_steps = max_steps
        self.baseline = BaselineRouter()
        self.readys = READYSRouter()

    def _run_policy(self, policy_fn, label: str, env=None) -> Dict:
        """Roll out ``policy_fn`` for ``num_eval_episodes`` episodes.

        Args:
            policy_fn: Callable ``(state, env) -> action index``.
            label: Name stored under ``"label"`` in the returned metrics.
            env: Optional pre-built environment; when omitted a fresh
                ``ComplexityAwarePIMEnv(max_steps=self.max_steps)`` is used.

        Returns:
            Dict with per-episode lists ``rewards``, ``energy_mj``,
            ``latency_ms`` and ``switches`` plus the per-target action
            ``counts`` accumulated over all episodes.
        """
        if env is None:
            env = ComplexityAwarePIMEnv(max_steps=self.max_steps)
        metrics = {
            "label": label, "rewards": [], "energy_mj": [],
            "latency_ms": [], "counts": {"PIM": 0, "CPU": 0, "GPU": 0},
            "switches": [],
        }
        for _episode in range(self.num_eval_episodes):
            state = env.reset()
            total_r, ep_energy, ep_latency = 0.0, [], []
            last_info = None  # stays None when the episode takes no steps
            for _step in range(self.max_steps):
                # Capture the profile the policy is deciding on *before*
                # stepping, so the estimates describe the task the action was
                # applied to even if step() advances current_profile.
                prof = env.current_profile
                action = policy_fn(state, env)
                state, reward, done, last_info = env.step(action)
                target = self.ACTION_NAMES[action]
                metrics["counts"][target] += 1
                total_r += reward
                ep_energy.append(env.profiler.estimate_energy(prof, target))
                ep_latency.append(env.profiler.estimate_latency(prof, target))
                if done:
                    break
            metrics["rewards"].append(total_r)
            metrics["energy_mj"].append(float(np.mean(ep_energy)) if ep_energy else 0.0)
            metrics["latency_ms"].append(float(np.mean(ep_latency)) if ep_latency else 0.0)
            # An episode with zero steps has no info dict; record 0 switches.
            metrics["switches"].append(
                last_info["switches"] if last_info is not None else 0)
        return metrics

    def evaluate_all(self, trained_agent) -> Dict[str, Dict]:
        """Evaluate the four baselines and ``trained_agent``.

        ``trained_agent.select_action`` is called with ``sensor``,
        ``task_profile`` and ``training=False`` keyword arguments.

        Returns:
            Mapping from method label to the metrics dict of ``_run_policy``.
        """
        action_ids = self.ACTION_IDS

        def always_pim(state, env):
            return 0

        def threshold_rule(state, env):
            T = env.sensor.T_current
            # Fall back to 0.6 when no voltage sample has been recorded yet.
            V_th = (env.sensor.voltage_history[-1]
                    if env.sensor.voltage_history else 0.6)
            target = self.baseline.route_threshold_rule(env.current_profile, T, V_th)
            return action_ids[target]

        def complexity_only(state, env):
            target = self.baseline.route_complexity_only(env.current_profile)
            return action_ids[target]

        def readys_route(state, env):
            target = self.readys.route(env.current_profile, sensor=env.sensor)
            return action_ids[target]

        def rl_policy(state, env):
            return trained_agent.select_action(
                state, sensor=env.sensor,
                task_profile=env.current_profile, training=False)

        policies = [
            ("Always-PIM", always_pim),
            ("Threshold-Rule", threshold_rule),
            ("Complexity-Only", complexity_only),
            ("READYS", readys_route),
            ("RL-Agent (ours)", rl_policy),
        ]
        return {label: self._run_policy(fn, label) for label, fn in policies}

    def print_comparison_table(self, results: Dict[str, Dict]) -> None:
        """Print one row per method: mean reward, energy, latency and PIM share."""
        print("\n" + "=" * 78)
        print(" BASELINE COMPARISON TABLE")
        print("=" * 78)
        header = f" {'Method':<22} {'Avg Reward':>12} {'Avg Energy(mJ)':>16} {'Avg Latency(ms)':>16} {'PIM%':>7}"
        print(header)
        print(" " + "-" * 74)
        for label, m in results.items():
            total = sum(m["counts"].values())
            # Share of all routed tasks that went to PIM; 0 when nothing ran.
            pim_pct = m["counts"]["PIM"] / total * 100 if total else 0
            print(f" {label:<22} "
                  f"{np.mean(m['rewards']):>12.2f} "
                  f"{np.mean(m['energy_mj']):>16.4f} "
                  f"{np.mean(m['latency_ms']):>16.4f} "
                  f"{pim_pct:>7.1f}%")
        print("=" * 78)
| print("=" * 78) |
|
|
|
|
| |
| |
| |
|
|
class AblationStudy:
    """Systematically removes one component at a time.

    Each variant is trained from scratch by ``_train_variant``; ``run``
    trains all variants and prints a summary table.
    """

    def __init__(self, num_episodes: int = 150, max_steps: int = 200,
                 device: str = "cpu"):
        self.num_episodes = num_episodes
        self.max_steps = max_steps
        self.device = device

    def _train_variant(self, variant_name: str,
                       use_dueling: bool = True,
                       use_per: bool = True,
                       use_safety_tier: bool = True,
                       state_dim: int = 16) -> Tuple[float, float]:
        """Train one ablation variant and summarize its final episodes.

        Args:
            variant_name: Human-readable name (not used in the computation).
            use_dueling: When False, the agent's networks are replaced with
                the flat ``FlatDQN`` and a new Adam optimizer (lr 5e-4).
            use_per: When False, the agent's memory is replaced with a
                uniform-sampling ring buffer.
            use_safety_tier: When True, ``sensor`` and ``task_profile`` are
                passed to ``select_action``; otherwise ``training=True`` only.
            state_dim: Agent input size; with 8, states are cut to their
                first 8 entries.

        Returns:
            ``(mean reward, mean switches)`` over the last 50 episodes, or
            over all episodes when fewer than 50 were run.
        """
        from rl_env import ComplexityAwarePIMEnv
        from rl_agent import ComplexityAwareRLAgent, Transition

        env = ComplexityAwarePIMEnv(max_steps=self.max_steps)
        agent = ComplexityAwareRLAgent(
            state_dim=state_dim, device=self.device,
            buffer_size=20000, batch_size=64)

        if not use_dueling:
            # Reuse the module-level FlatDQN; a second copy of the same
            # architecture used to be defined inline here.
            torch_device = torch.device(self.device)
            agent.policy_net = FlatDQN(state_dim).to(torch_device)
            agent.target_net = FlatDQN(state_dim).to(torch_device)
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
            agent.optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=5e-4)

        if not use_per:
            class UniformBuffer:
                """Fixed-capacity ring buffer with uniform sampling.

                NOTE(review): ``sample`` returns a 3-tuple
                ``(Transition batch, indices, weights)``, while elsewhere in
                this file ``PrioritizedReplayBuffer.sample`` is unpacked as a
                7-tuple of tensors.  Confirm this matches what
                ``ComplexityAwareRLAgent.train_step`` expects.
                """

                def __init__(self, capacity=20000):
                    self.buf = []
                    self.capacity = capacity
                    self.pos = 0
                    self.size = 0

                def push(self, *args):
                    if self.size < self.capacity:
                        self.buf.append(Transition(*args))
                        self.size += 1
                    else:
                        # Buffer is full: overwrite the oldest slot.
                        self.buf[self.pos] = Transition(*args)
                    self.pos = (self.pos + 1) % self.capacity

                def sample(self, batch_size, beta=0.4):
                    # ``beta`` is accepted but unused by uniform sampling.
                    idxs = np.random.choice(self.size, batch_size, replace=False)
                    samples = [self.buf[i] for i in idxs]
                    weights = torch.ones(batch_size)
                    return Transition(*zip(*samples)), idxs, weights

                def update_priorities(self, indices, td_errors):
                    # Uniform replay has no priorities to update.
                    pass

                def __len__(self):
                    return self.size

            agent.memory = UniformBuffer(20000)

        rewards, switches = [], []
        for _episode in range(self.num_episodes):
            state = env.reset()
            if state_dim == 8:
                state = state[:8]
            total_r = 0
            last_info = None  # stays None when the episode takes no steps
            for _step in range(self.max_steps):
                if use_safety_tier:
                    action = agent.select_action(
                        state, sensor=env.sensor,
                        task_profile=env.current_profile)
                else:
                    action = agent.select_action(state, training=True)
                next_state, reward, done, last_info = env.step(action)
                if state_dim == 8:
                    next_state = next_state[:8]
                agent.store_transition(state, action, reward, next_state, float(done))
                agent.train_step()
                total_r += reward
                state = next_state
                if done:
                    break
            rewards.append(total_r)
            # An episode with zero steps has no info dict; record 0 switches.
            switches.append(last_info["switches"] if last_info is not None else 0)

        # A negative slice already yields the whole list when it is shorter
        # than 50 entries.
        return float(np.mean(rewards[-50:])), float(np.mean(switches[-50:]))

    def run(self) -> Dict[str, Dict]:
        """Train every variant, print the results table and return the results.

        Returns:
            Mapping from variant name to
            ``{"mean_reward": float, "mean_switches": float}``.
        """
        print("\n--- Ablation Study ---")
        results = {}
        # (name, use_dueling, use_per, use_safety_tier, state_dim)
        variants = [
            ("Full system", True, True, True, 16),
            ("No dueling (flat DQN)", False, True, True, 16),
            ("No PER (uniform replay)", True, False, True, 16),
            ("No 3-tier hierarchy", True, True, False, 16),
            ("Physics-only state (8D)", True, True, True, 8),
        ]
        for name, dueling, per, safety, sdim in variants:
            print(f" Training variant: {name}...")
            r, sw = self._train_variant(name, dueling, per, safety, sdim)
            results[name] = {"mean_reward": r, "mean_switches": sw}
            print(f" → mean_reward={r:.2f}, mean_switches={sw:.1f}")

        print("\n ABLATION RESULTS (last-50-episode averages):")
        print(f" {'Variant':<35} {'Mean Reward':>13} {'Mean Switches':>14}")
        print(" " + "-" * 62)
        baseline_r = results["Full system"]["mean_reward"]
        for name, m in results.items():
            # Reward lost relative to the full system; shown only when it
            # exceeds 0.1.
            drop = baseline_r - m["mean_reward"]
            drop_str = f" (↓{drop:.2f})" if drop > 0.1 else ""
            # The header announces a switches column, so print it.
            print(f" {name:<35} {m['mean_reward']:>13.2f} "
                  f"{m['mean_switches']:>14.1f}{drop_str}")
        return results
|
|