""" baselines.py ============ Deterministic baseline routers and RL-based baselines. Literature: - READYS: Grinsztajn et al. (IEEE Cluster 2021) - EdgeSched-DQN: ScienceDirect 2025 - Das et al. (DAC 2014) — thermal optimization - Lee, Shin, Chwa (ACM TECS 2019) — thermal-aware scheduling """ import random import math from typing import Dict, List, Optional, Tuple from collections import Counter import numpy as np try: import torch import torch.nn as nn HAS_TORCH = True except ImportError: HAS_TORCH = False from profiler import TaskComplexityProfile, TaskComplexityProfiler from rl_env import ComplexityAwarePIMEnv class BaselineRouter: """Three deterministic baselines.""" def __init__(self): self.profiler = TaskComplexityProfiler() def route_always_pim(self, profile: TaskComplexityProfile) -> str: return "PIM" def route_threshold_rule(self, profile: TaskComplexityProfile, T: float, V_th: float) -> str: if V_th > 0.85: return "CPU" if T > 85.0: return "GPU" if profile.complexity_class == "HEAVY": return "GPU" if profile.is_memory_bound and profile.complexity_class == "LIGHT": return "PIM" return "CPU" def route_complexity_only(self, profile: TaskComplexityProfile) -> str: scores = self.profiler.compute_suitability_scores(profile) return max(scores, key=scores.get) def route_standard_dqn(self, state: np.ndarray, policy_net) -> int: with torch.no_grad(): state_t = torch.FloatTensor(state).unsqueeze(0) q_values = policy_net(state_t).cpu().numpy()[0] return int(np.argmax(q_values)) # ═══════════════════════════════════════════════════════════════════════════════ # READYS-style Greedy Scheduler (Grinsztajn et al. 2021) # ═══════════════════════════════════════════════════════════════════════════════ class READYSRouter: """ READYS-inspired greedy heuristic: score = deadline_slack / estimated_exec_time, pick highest. Adapted to our 3-target discrete setting. 
""" def __init__(self): self.profiler = TaskComplexityProfiler() def route(self, profile: TaskComplexityProfile, sensor=None, deadline_ms: float = 100.0) -> str: est = {} for t in ["PIM", "CPU", "GPU"]: est[t] = self.profiler.estimate_latency(profile, t) scores = {} for t in ["PIM", "CPU", "GPU"]: slack = deadline_ms - est[t] scores[t] = max(slack, 0.01) / max(est[t], 0.001) # Safety overrides if sensor: if getattr(sensor, 'T_current', 25.0) > 85.0: return "GPU" if (hasattr(sensor, 'voltage_history') and sensor.voltage_history and sensor.voltage_history[-1] > 0.85): return "CPU" return max(scores, key=scores.get) # ═══════════════════════════════════════════════════════════════════════════════ # EdgeSched-DQN style Flat DQN Baseline # ═══════════════════════════════════════════════════════════════════════════════ class FlatDQN(nn.Module): """Standard (non-dueling) DQN with state+task size inputs.""" def __init__(self, state_dim=16, action_dim=3, hidden_dim=256): super().__init__() self.net = nn.Sequential( nn.Linear(state_dim, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, action_dim), ) def forward(self, x): return self.net(x) class EdgeSchedDQNAgent: """ Flat DQN baseline matching EdgeSched-DQN architecture. No dueling, no PER, no 3-tier hierarchy. 
""" def __init__(self, state_dim=16, action_dim=3, hidden_dim=256, lr=5e-4, gamma=0.99, tau=0.005, buffer_size=50000, batch_size=128, device="cpu"): if not HAS_TORCH: raise RuntimeError("PyTorch required.") self.device = torch.device(device) self.gamma = gamma self.tau = tau self.batch_size = batch_size self.steps_done = 0 self.policy_net = FlatDQN(state_dim, action_dim, hidden_dim).to(self.device) self.target_net = FlatDQN(state_dim, action_dim, hidden_dim).to(self.device) self.target_net.load_state_dict(self.policy_net.state_dict()) self.target_net.eval() self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=lr) from rl_agent import PrioritizedReplayBuffer self.memory = PrioritizedReplayBuffer(buffer_size, device=device) self.action_dim = action_dim def select_action(self, state: np.ndarray, epsilon: float = 0.0) -> int: if random.random() < epsilon: return random.randrange(self.action_dim) with torch.no_grad(): q = self.policy_net(torch.FloatTensor(state).unsqueeze(0).to(self.device)) return int(q.argmax(dim=1).item()) def store_transition(self, *args): self.memory.push(*args) def train_step(self): if len(self.memory) < self.batch_size: return None states, actions, rewards, next_states, dones, indices, weights = \ self.memory.sample(self.batch_size) current_q = self.policy_net(states).gather(1, actions).squeeze() with torch.no_grad(): next_q = self.target_net(next_states).max(dim=1)[0] target_q = rewards + (1 - dones) * self.gamma * next_q td_errors = (current_q - target_q).detach().cpu().numpy() self.memory.update_priorities(indices, td_errors) loss = (weights * nn.functional.smooth_l1_loss( current_q, target_q, reduction='none')).mean() self.optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10.0) self.optimizer.step() for tp, pp in zip(self.target_net.parameters(), self.policy_net.parameters()): tp.data.copy_(self.tau * pp.data + (1 - self.tau) * tp.data) return float(loss.item()) # 
# ═══════════════════════════════════════════════════════════════════════════════
# Baseline Evaluator
# ═══════════════════════════════════════════════════════════════════════════════

class BaselineEvaluator:
    """Roll out each baseline router and a trained agent, then compare metrics."""

    ACTION_NAMES = {0: "PIM", 1: "CPU", 2: "GPU"}
    # Inverse of ACTION_NAMES: router target name -> env action index.
    ACTION_INDEX = {"PIM": 0, "CPU": 1, "GPU": 2}

    def __init__(self, num_eval_episodes: int = 50, max_steps: int = 200):
        self.num_eval_episodes = num_eval_episodes
        self.max_steps = max_steps
        self.baseline = BaselineRouter()
        self.readys = READYSRouter()

    def _run_policy(self, policy_fn, label: str) -> Dict:
        """Evaluate ``policy_fn`` for ``num_eval_episodes`` episodes.

        Args:
            policy_fn: callable ``(state, env) -> action index`` (a key of
                ACTION_NAMES).
            label: display name stored under ``"label"``.

        Returns:
            Dict with per-episode lists ``rewards`` (episode return),
            ``energy_mj`` / ``latency_ms`` (mean of the profiler's per-step
            estimates for the chosen target), ``switches`` (``info["switches"]``
            at episode end), and cumulative per-target action ``counts``.
        """
        env = ComplexityAwarePIMEnv(max_steps=self.max_steps)
        metrics = {
            "label": label, "rewards": [], "energy_mj": [], "latency_ms": [],
            "counts": {"PIM": 0, "CPU": 0, "GPU": 0}, "switches": [],
        }
        for _ in range(self.num_eval_episodes):
            state = env.reset()
            total_r, ep_energy, ep_latency = 0.0, [], []
            info = {}  # stays empty if max_steps == 0
            for _ in range(self.max_steps):
                action = policy_fn(state, env)
                # Capture the profile of the task being routed BEFORE stepping:
                # step() may advance env.current_profile to the next task, which
                # would charge this action's cost against the wrong task.
                prof = env.current_profile
                target = self.ACTION_NAMES[action]
                state, reward, done, info = env.step(action)
                metrics["counts"][target] += 1
                total_r += reward
                ep_energy.append(env.profiler.estimate_energy(prof, target))
                ep_latency.append(env.profiler.estimate_latency(prof, target))
                if done:
                    break
            metrics["rewards"].append(total_r)
            metrics["energy_mj"].append(float(np.mean(ep_energy)) if ep_energy else 0.0)
            metrics["latency_ms"].append(float(np.mean(ep_latency)) if ep_latency else 0.0)
            metrics["switches"].append(info.get("switches", 0))
        return metrics

    def evaluate_all(self, trained_agent) -> Dict[str, Dict]:
        """Evaluate all baselines plus ``trained_agent``; returns metrics keyed by label.

        ``trained_agent`` must expose
        ``select_action(state, sensor=..., task_profile=..., training=False)``.
        """
        results = {}

        def always_pim(state, env):
            return 0
        results["Always-PIM"] = self._run_policy(always_pim, "Always-PIM")

        def threshold_rule(state, env):
            T = env.sensor.T_current
            V_th = (env.sensor.voltage_history[-1]
                    if env.sensor.voltage_history else 0.6)
            target = self.baseline.route_threshold_rule(env.current_profile, T, V_th)
            return self.ACTION_INDEX[target]
        results["Threshold-Rule"] = self._run_policy(threshold_rule, "Threshold-Rule")

        def complexity_only(state, env):
            target = self.baseline.route_complexity_only(env.current_profile)
            return self.ACTION_INDEX[target]
        results["Complexity-Only"] = self._run_policy(complexity_only, "Complexity-Only")

        def readys_route(state, env):
            target = self.readys.route(env.current_profile, sensor=env.sensor)
            return self.ACTION_INDEX[target]
        results["READYS"] = self._run_policy(readys_route, "READYS")

        def rl_agent(state, env):
            return trained_agent.select_action(
                state, sensor=env.sensor, task_profile=env.current_profile,
                training=False)
        results["RL-Agent (ours)"] = self._run_policy(rl_agent, "RL-Agent (ours)")

        return results

    def print_comparison_table(self, results: Dict[str, Dict]) -> None:
        """Print mean reward/energy/latency and PIM share for each evaluated method."""
        print("\n" + "=" * 78)
        print(" BASELINE COMPARISON TABLE")
        print("=" * 78)
        header = f" {'Method':<22} {'Avg Reward':>12} {'Avg Energy(mJ)':>16} {'Avg Latency(ms)':>16} {'PIM%':>7}"
        print(header)
        print(" " + "-" * 74)
        for label, m in results.items():
            total = sum(m["counts"].values())
            pim_pct = m["counts"]["PIM"] / total * 100 if total else 0
            print(f" {label:<22} "
                  f"{np.mean(m['rewards']):>12.2f} "
                  f"{np.mean(m['energy_mj']):>16.4f} "
                  f"{np.mean(m['latency_ms']):>16.4f} "
                  f"{pim_pct:>7.1f}%")
        print("=" * 78)


# ═══════════════════════════════════════════════════════════════════════════════
# Ablation Study Framework
# ═══════════════════════════════════════════════════════════════════════════════

class AblationStudy:
    """Systematically removes one component at a time."""

    # Number of trailing episodes averaged for the reported metrics.
    REPORT_WINDOW = 50

    def __init__(self, num_episodes: int = 150, max_steps: int = 200,
                 device: str = "cpu"):
        self.num_episodes = num_episodes
        self.max_steps = max_steps
        self.device = device

    def _train_variant(self, variant_name: str, use_dueling: bool = True,
                       use_per: bool = True, use_safety_tier: bool = True,
                       state_dim: int = 16) -> Tuple[float, float]:
        """Train one ablation variant from scratch.

        Args:
            variant_name: display label (not used for training).
            use_dueling: if False, swap the agent's networks for the flat DQN.
            use_per: if False, swap the replay memory for a uniform buffer.
            use_safety_tier: if False, call ``select_action`` without the
                sensor/task-profile arguments.
            state_dim: 16 for the full state; 8 truncates observations to
                their first 8 entries.

        Returns:
            ``(mean reward, mean switches)`` over the last REPORT_WINDOW
            episodes (or all episodes if fewer were run).
        """
        from rl_agent import ComplexityAwareRLAgent, Transition

        torch_device = torch.device(self.device)
        env = ComplexityAwarePIMEnv(max_steps=self.max_steps)
        agent = ComplexityAwareRLAgent(
            state_dim=state_dim, device=self.device,
            buffer_size=20000, batch_size=64)

        if not use_dueling:
            # Reuse the module-level FlatDQN instead of redefining an
            # identical architecture locally.
            agent.policy_net = FlatDQN(state_dim).to(torch_device)
            agent.target_net = FlatDQN(state_dim).to(torch_device)
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
            agent.target_net.eval()
            agent.optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=5e-4)

        if not use_per:
            # NOTE(review): sample() returns a 3-tuple (Transition of tuples,
            # indices, unit weights), whereas EdgeSchedDQNAgent.train_step
            # unpacks seven values from PrioritizedReplayBuffer.sample. Confirm
            # ComplexityAwareRLAgent.train_step accepts this format.
            class UniformBuffer:
                """Fixed-capacity ring buffer sampled uniformly without replacement."""

                def __init__(self, capacity=20000):
                    self.buf = []
                    self.capacity = capacity
                    self.pos = 0
                    self.size = 0

                def push(self, *args):
                    if self.size < self.capacity:
                        self.buf.append(Transition(*args))
                        self.size += 1
                    else:
                        self.buf[self.pos] = Transition(*args)
                    self.pos = (self.pos + 1) % self.capacity

                def sample(self, batch_size, beta=0.4):
                    idxs = np.random.choice(self.size, batch_size, replace=False)
                    samples = [self.buf[i] for i in idxs]
                    weights = torch.ones(batch_size)
                    return Transition(*zip(*samples)), idxs, weights

                def update_priorities(self, indices, td_errors):
                    pass  # uniform replay: priorities are irrelevant

                def __len__(self):
                    return self.size

            agent.memory = UniformBuffer(20000)

        rewards, switches = [], []
        for _ in range(self.num_episodes):
            state = env.reset()
            if state_dim == 8:
                state = state[:8]
            total_r = 0
            info = {}  # stays empty if max_steps == 0
            for _ in range(self.max_steps):
                if use_safety_tier:
                    action = agent.select_action(
                        state, sensor=env.sensor, task_profile=env.current_profile)
                else:
                    action = agent.select_action(state, training=True)
                next_state, reward, done, info = env.step(action)
                if state_dim == 8:
                    next_state = next_state[:8]
                agent.store_transition(state, action, reward, next_state, float(done))
                agent.train_step()
                total_r += reward
                state = next_state
                if done:
                    break
            rewards.append(total_r)
            switches.append(info.get("switches", 0))

        # Negative slicing already yields the whole list when it is shorter.
        window = self.REPORT_WINDOW
        return float(np.mean(rewards[-window:])), float(np.mean(switches[-window:]))

    def run(self) -> Dict[str, Dict]:
        """Train every variant, print a summary table, and return per-variant metrics."""
        print("\n--- Ablation Study ---")
        results = {}
        variants = [
            ("Full system", True, True, True, 16),
            ("No dueling (flat DQN)", False, True, True, 16),
            ("No PER (uniform replay)", True, False, True, 16),
            ("No 3-tier hierarchy", True, True, False, 16),
            ("Physics-only state (8D)", True, True, True, 8),
        ]
        for name, dueling, per, safety, sdim in variants:
            print(f" Training variant: {name}...")
            r, sw = self._train_variant(name, dueling, per, safety, sdim)
            results[name] = {"mean_reward": r, "mean_switches": sw}
            print(f" → mean_reward={r:.2f}, mean_switches={sw:.1f}")

        print("\n ABLATION RESULTS (last-50-episode averages):")
        print(f" {'Variant':<35} {'Mean Reward':>13} {'Mean Switches':>14}")
        print(" " + "-" * 62)
        baseline_r = results["Full system"]["mean_reward"]
        for name, m in results.items():
            drop = baseline_r - m["mean_reward"]
            drop_str = f" (−{drop:.2f})" if drop > 0.1 else ""
            # The header promises a "Mean Switches" column, so print it.
            print(f" {name:<35} {m['mean_reward']:>13.2f} "
                  f"{m['mean_switches']:>14.1f}{drop_str}")
        return results