tempo-snn-v2 / src /baselines.py
KD099's picture
Upload folder using huggingface_hub
a157e36 verified
"""
baselines.py
============
Deterministic baseline routers and RL-based baselines.
Literature:
- READYS: Grinsztajn et al. (IEEE Cluster 2021)
- EdgeSched-DQN: ScienceDirect 2025
- Das et al. (DAC 2014) — thermal optimization
- Lee, Shin, Chwa (ACM TECS 2019) — thermal-aware scheduling
"""
import random
import math
from typing import Dict, List, Optional, Tuple
from collections import Counter
import numpy as np
try:
import torch
import torch.nn as nn
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
from profiler import TaskComplexityProfile, TaskComplexityProfiler
from rl_env import ComplexityAwarePIMEnv
class BaselineRouter:
    """Deterministic routing baselines plus a greedy wrapper for a trained DQN.

    The ``route_always_pim``, ``route_threshold_rule`` and
    ``route_complexity_only`` methods return a target name ("PIM", "CPU" or
    "GPU"); ``route_standard_dqn`` returns an integer action index instead.
    """

    # Cut-offs used by ``route_threshold_rule``. Units are whatever the caller
    # supplies for ``V_th`` and ``T`` (not fixed by this class).
    V_TH_LIMIT = 0.85
    T_LIMIT = 85.0

    def __init__(self):
        self.profiler = TaskComplexityProfiler()

    def route_always_pim(self, profile: TaskComplexityProfile) -> str:
        """Return "PIM" unconditionally; ``profile`` is ignored."""
        return "PIM"

    def route_threshold_rule(self, profile: TaskComplexityProfile,
                             T: float, V_th: float) -> str:
        """Pick a target with a fixed, ordered rule list.

        The checks run in this order and the first match wins:

        1. ``V_th`` above ``V_TH_LIMIT``             -> "CPU"
        2. ``T`` above ``T_LIMIT``                   -> "GPU"
        3. profile is "HEAVY"                        -> "GPU"
        4. profile is memory bound and "LIGHT"       -> "PIM"
        5. anything else                             -> "CPU"

        Args:
            profile: Task profile; ``complexity_class`` and ``is_memory_bound``
                are read from it.
            T: Scalar compared against ``T_LIMIT``.
            V_th: Scalar compared against ``V_TH_LIMIT``.
        """
        if V_th > self.V_TH_LIMIT:
            return "CPU"
        if T > self.T_LIMIT:
            return "GPU"
        if profile.complexity_class == "HEAVY":
            return "GPU"
        if profile.is_memory_bound and profile.complexity_class == "LIGHT":
            return "PIM"
        return "CPU"

    def route_complexity_only(self, profile: TaskComplexityProfile) -> str:
        """Return the target with the highest profiler suitability score."""
        scores = self.profiler.compute_suitability_scores(profile)
        return max(scores, key=scores.get)

    def route_standard_dqn(self, state: np.ndarray, policy_net) -> int:
        """Return the greedy (arg-max Q) action index of ``policy_net``.

        The state tensor is created on the same device as the network's
        parameters so that a network living on an accelerator can be queried
        without a device-mismatch error. Objects that expose no parameters
        are fed a CPU tensor.

        Args:
            state: 1-D observation; a batch axis is added before the forward
                pass.
            policy_net: Callable mapping a ``(1, state_dim)`` float tensor to
                a ``(1, n_actions)`` tensor of Q-values.
        """
        try:
            device = next(policy_net.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            # Not an nn.Module, or a module without parameters.
            device = torch.device("cpu")

        with torch.no_grad():
            state_t = torch.as_tensor(
                state, dtype=torch.float32, device=device).unsqueeze(0)
            q_values = policy_net(state_t).cpu().numpy()[0]
        return int(np.argmax(q_values))
# ═══════════════════════════════════════════════════════════════════════════════
# READYS-style Greedy Scheduler (Grinsztajn et al. 2021)
# ═══════════════════════════════════════════════════════════════════════════════
class READYSRouter:
    """
    READYS-inspired greedy heuristic:
    score = deadline_slack / estimated_exec_time, pick highest.
    Adapted to our 3-target discrete setting.
    """

    # Candidate targets, in tie-break order (the first best score wins).
    TARGETS = ("PIM", "CPU", "GPU")

    def __init__(self):
        self.profiler = TaskComplexityProfiler()

    def route(self, profile: TaskComplexityProfile,
              sensor=None,
              deadline_ms: float = 100.0) -> str:
        """Choose a target for ``profile``.

        Safety overrides are evaluated first, so the latency estimates are
        only computed when they are actually needed:

        * ``sensor.T_current`` above 85.0 -> "GPU" (a missing attribute is
          treated as 25.0);
        * latest ``sensor.voltage_history`` entry above 0.85 -> "CPU".

        Otherwise each target is scored as
        ``max(deadline_ms - est, 0.01) / max(est, 0.001)`` where ``est`` is
        ``profiler.estimate_latency(profile, target)``, and the best-scoring
        target is returned.

        Args:
            profile: Task profile handed to the profiler.
            sensor: Optional object exposing ``T_current`` and/or
                ``voltage_history``. ``None`` disables the overrides.
            deadline_ms: Deadline used to compute the slack.

        Returns:
            One of "PIM", "CPU", "GPU".
        """
        # Explicit None check: a sensor object that happens to be falsy
        # (e.g. one defining __len__/__bool__) must still be consulted.
        if sensor is not None:
            if getattr(sensor, 'T_current', 25.0) > 85.0:
                return "GPU"
            history = getattr(sensor, 'voltage_history', None)
            if history and history[-1] > 0.85:
                return "CPU"

        scores = {}
        for target in self.TARGETS:
            est = self.profiler.estimate_latency(profile, target)
            slack = deadline_ms - est
            # Both terms are clamped so the ratio stays positive and finite
            # even when the deadline is already missed or est is ~0.
            scores[target] = max(slack, 0.01) / max(est, 0.001)
        return max(scores, key=scores.get)
# ═══════════════════════════════════════════════════════════════════════════════
# EdgeSched-DQN style Flat DQN Baseline
# ═══════════════════════════════════════════════════════════════════════════════
class FlatDQN(nn.Module):
    """Plain (non-dueling) feed-forward Q-network.

    Architecture: two hidden blocks of Linear -> ReLU -> LayerNorm followed by
    a linear output head producing one Q-value per action.
    """

    def __init__(self, state_dim=16, action_dim=3, hidden_dim=256):
        super().__init__()
        # Layers are listed in forward order; keeping them inside a single
        # Sequential named ``net`` fixes the state-dict key layout (net.0 ...).
        stack = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Map a batch of states to a batch of per-action Q-values."""
        q_values = self.net(x)
        return q_values
class EdgeSchedDQNAgent:
    """
    Flat DQN baseline matching EdgeSched-DQN architecture.
    No dueling, no 3-tier hierarchy.

    NOTE(review): this baseline was described as "no PER", yet the replay
    memory built in ``__init__`` is ``rl_agent.PrioritizedReplayBuffer`` and
    ``train_step`` feeds TD errors back into it. Confirm which is intended.
    """

    def __init__(self, state_dim=16, action_dim=3, hidden_dim=256,
                 lr=5e-4, gamma=0.99, tau=0.005, buffer_size=50000,
                 batch_size=128, device="cpu"):
        """Build policy/target networks, optimizer and replay memory.

        Args:
            state_dim: Size of the observation vector.
            action_dim: Number of discrete actions.
            hidden_dim: Width of both hidden layers.
            lr: Adam learning rate.
            gamma: Discount factor used in the TD target.
            tau: Interpolation factor of the soft target update.
            buffer_size: Capacity handed to the replay buffer.
            batch_size: Mini-batch size; also the minimum number of stored
                transitions before ``train_step`` does any work.
            device: Torch device string.

        Raises:
            RuntimeError: If PyTorch could not be imported.
        """
        if not HAS_TORCH:
            raise RuntimeError("PyTorch required.")
        self.device = torch.device(device)
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.steps_done = 0
        self.action_dim = action_dim

        self.policy_net = FlatDQN(state_dim, action_dim, hidden_dim).to(self.device)
        self.target_net = FlatDQN(state_dim, action_dim, hidden_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=lr)

        # Imported lazily, as in the original code.
        from rl_agent import PrioritizedReplayBuffer
        self.memory = PrioritizedReplayBuffer(buffer_size, device=device)

    def select_action(self, state: np.ndarray, epsilon: float = 0.0) -> int:
        """Epsilon-greedy action selection.

        With probability ``epsilon`` a uniformly random action index is
        returned; otherwise the arg-max of the policy network's Q-values.
        """
        if random.random() < epsilon:
            return random.randrange(self.action_dim)
        with torch.no_grad():
            state_t = torch.as_tensor(
                state, dtype=torch.float32, device=self.device).unsqueeze(0)
            q = self.policy_net(state_t)
        return int(q.argmax(dim=1).item())

    def store_transition(self, *args):
        """Forward a transition to the replay memory unchanged."""
        self.memory.push(*args)

    def train_step(self):
        """Run one optimisation step followed by a soft target update.

        Returns:
            The scalar loss as a float, or ``None`` when the memory holds
            fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return None

        (states, actions, rewards, next_states,
         dones, indices, weights) = self.memory.sample(self.batch_size)

        # Normalise shapes explicitly. Per-sample scalars are flattened to
        # (B,) so a (B, 1) reward/done/weight column cannot silently
        # broadcast against the (B,) Q-vectors into a (B, B) matrix.
        actions = actions.reshape(-1, 1)
        rewards = rewards.reshape(-1)
        dones = dones.reshape(-1)
        weights = weights.reshape(-1)

        # squeeze(1) rather than squeeze(): only the gathered column is
        # dropped, so a single-element batch keeps its batch axis.
        current_q = self.policy_net(states).gather(1, actions).squeeze(1)

        with torch.no_grad():
            next_q = self.target_net(next_states).max(dim=1)[0]
            target_q = rewards + (1 - dones) * self.gamma * next_q

        td_errors = (current_q - target_q).detach().cpu().numpy()
        self.memory.update_priorities(indices, td_errors)

        per_sample = nn.functional.smooth_l1_loss(
            current_q, target_q, reduction='none')
        loss = (weights * per_sample).mean()

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10.0)
        self.optimizer.step()

        # Soft (Polyak) update of the target network, done under no_grad
        # instead of going through the legacy ``.data`` attribute.
        with torch.no_grad():
            for tp, pp in zip(self.target_net.parameters(),
                              self.policy_net.parameters()):
                tp.copy_(self.tau * pp + (1 - self.tau) * tp)

        return float(loss.item())
# ═══════════════════════════════════════════════════════════════════════════════
# Baseline Evaluator
# ═══════════════════════════════════════════════════════════════════════════════
class BaselineEvaluator:
    """Runs every baseline policy in the environment and tabulates the results."""

    ACTION_NAMES = {0: "PIM", 1: "CPU", 2: "GPU"}
    # Inverse lookup (target name -> action index), derived once from
    # ACTION_NAMES so the two mappings cannot drift apart.
    ACTION_IDS = {name: idx for idx, name in ACTION_NAMES.items()}

    def __init__(self, num_eval_episodes: int = 50, max_steps: int = 200):
        """
        Args:
            num_eval_episodes: Episodes rolled out per policy.
            max_steps: Step cap per episode (also passed to the environment).
        """
        self.num_eval_episodes = num_eval_episodes
        self.max_steps = max_steps
        self.baseline = BaselineRouter()
        self.readys = READYSRouter()

    @staticmethod
    def _mean(values) -> float:
        """Mean of ``values``; 0.0 for an empty sequence (avoids NaN + warning)."""
        return float(np.mean(values)) if len(values) else 0.0

    def _run_policy(self, policy_fn, label: str) -> Dict:
        """Roll out ``policy_fn`` for ``num_eval_episodes`` episodes.

        Args:
            policy_fn: Callable ``(state, env) -> action index``.
            label: Name stored under the ``"label"`` key of the result.

        Returns:
            Dict with per-episode lists ``rewards``, ``energy_mj``,
            ``latency_ms`` and ``switches`` plus cumulative per-target
            ``counts``.
        """
        env = ComplexityAwarePIMEnv(max_steps=self.max_steps)
        metrics = {
            "label": label, "rewards": [], "energy_mj": [],
            "latency_ms": [], "counts": {"PIM": 0, "CPU": 0, "GPU": 0},
            "switches": [],
        }
        for _ in range(self.num_eval_episodes):
            state = env.reset()
            total_r, ep_energy, ep_latency = 0.0, [], []
            # Defined up front so an episode with zero steps still reports a
            # switch count instead of raising on an unbound name.
            info = {"switches": 0}
            for _step in range(self.max_steps):
                # Snapshot the profile the policy is about to route. Reading
                # it after env.step() risks picking up the *next* task if the
                # environment advances current_profile inside step().
                prof = env.current_profile
                action = policy_fn(state, env)
                state, reward, done, info = env.step(action)

                target = self.ACTION_NAMES[action]
                metrics["counts"][target] += 1
                total_r += reward
                ep_energy.append(env.profiler.estimate_energy(prof, target))
                ep_latency.append(env.profiler.estimate_latency(prof, target))
                if done:
                    break
            metrics["rewards"].append(total_r)
            metrics["energy_mj"].append(self._mean(ep_energy))
            metrics["latency_ms"].append(self._mean(ep_latency))
            metrics["switches"].append(info["switches"])
        return metrics

    def evaluate_all(self, trained_agent) -> Dict[str, Dict]:
        """Evaluate the four baselines and ``trained_agent``.

        Args:
            trained_agent: Object exposing
                ``select_action(state, sensor=..., task_profile=..., training=...)``.

        Returns:
            Mapping from method label to the metrics dict of ``_run_policy``.
        """
        action_ids = self.ACTION_IDS
        results = {}

        def always_pim(state, env):
            return 0
        results["Always-PIM"] = self._run_policy(always_pim, "Always-PIM")

        def threshold_rule(state, env):
            T = env.sensor.T_current
            # Falls back to 0.6 while the sensor has no voltage samples yet.
            V_th = (env.sensor.voltage_history[-1]
                    if env.sensor.voltage_history else 0.6)
            target = self.baseline.route_threshold_rule(env.current_profile, T, V_th)
            return action_ids[target]
        results["Threshold-Rule"] = self._run_policy(threshold_rule, "Threshold-Rule")

        def complexity_only(state, env):
            target = self.baseline.route_complexity_only(env.current_profile)
            return action_ids[target]
        results["Complexity-Only"] = self._run_policy(complexity_only, "Complexity-Only")

        def readys_route(state, env):
            target = self.readys.route(env.current_profile, sensor=env.sensor)
            return action_ids[target]
        results["READYS"] = self._run_policy(readys_route, "READYS")

        def rl_agent(state, env):
            return trained_agent.select_action(
                state, sensor=env.sensor,
                task_profile=env.current_profile, training=False)
        results["RL-Agent (ours)"] = self._run_policy(rl_agent, "RL-Agent (ours)")
        return results

    def print_comparison_table(self, results: Dict[str, Dict]) -> None:
        """Print one row per method: mean reward, energy, latency and PIM share.

        Empty metric lists are reported as 0 rather than NaN.
        """
        print("\n" + "=" * 78)
        print(" BASELINE COMPARISON TABLE")
        print("=" * 78)
        header = f" {'Method':<22} {'Avg Reward':>12} {'Avg Energy(mJ)':>16} {'Avg Latency(ms)':>16} {'PIM%':>7}"
        print(header)
        print(" " + "-" * 74)
        for label, m in results.items():
            total = sum(m["counts"].values())
            pim_pct = m["counts"]["PIM"] / total * 100 if total else 0
            print(f" {label:<22} "
                  f"{self._mean(m['rewards']):>12.2f} "
                  f"{self._mean(m['energy_mj']):>16.4f} "
                  f"{self._mean(m['latency_ms']):>16.4f} "
                  f"{pim_pct:>7.1f}%")
        print("=" * 78)
# ═══════════════════════════════════════════════════════════════════════════════
# Ablation Study Framework
# ═══════════════════════════════════════════════════════════════════════════════
class AblationStudy:
    """Systematically removes one component at a time."""

    class _UniformBuffer:
        """Fixed-capacity ring buffer with uniform sampling.

        Used in place of the agent's replay memory for the "no PER" variant.
        ``sample`` returns ``(batch, indices, weights)`` with all-ones weights
        and ``update_priorities`` is a no-op.
        """

        def __init__(self, capacity, transition_cls, device="cpu"):
            self.buf = []
            self.capacity = capacity
            self.pos = 0
            self.size = 0
            self._transition_cls = transition_cls
            self._device = device

        def push(self, *args):
            """Store a transition, overwriting the oldest slot once full."""
            item = self._transition_cls(*args)
            if self.size < self.capacity:
                self.buf.append(item)
                self.size += 1
            else:
                self.buf[self.pos] = item
            self.pos = (self.pos + 1) % self.capacity

        def sample(self, batch_size, beta=0.4):
            """Draw ``batch_size`` distinct transitions; ``beta`` is ignored."""
            idxs = np.random.choice(self.size, batch_size, replace=False)
            samples = [self.buf[i] for i in idxs]
            # Weights live on the training device so they can be multiplied
            # with the loss without a device mismatch.
            weights = torch.ones(batch_size, device=self._device)
            return self._transition_cls(*zip(*samples)), idxs, weights

        def update_priorities(self, indices, td_errors):
            """No-op: uniform replay has no priorities."""
            pass

        def __len__(self):
            return self.size

    def __init__(self, num_episodes: int = 150, max_steps: int = 200,
                 device: str = "cpu"):
        """
        Args:
            num_episodes: Training episodes per variant.
            max_steps: Step cap per episode (also passed to the environment).
            device: Torch device string used for every variant.
        """
        self.num_episodes = num_episodes
        self.max_steps = max_steps
        self.device = device

    @staticmethod
    def _tail_mean(values, window: int = 50) -> float:
        """Mean of the last ``window`` entries; 0.0 when ``values`` is empty."""
        tail = values[-window:]
        return float(np.mean(tail)) if tail else 0.0

    def _train_variant(self, variant_name: str,
                       use_dueling: bool = True,
                       use_per: bool = True,
                       use_safety_tier: bool = True,
                       state_dim: int = 16) -> Tuple[float, float]:
        """Train one ablation variant and summarise its final performance.

        Args:
            variant_name: Human-readable label (not used by the training
                itself; kept for interface compatibility).
            use_dueling: When False, the agent's networks are replaced by the
                module-level ``FlatDQN``.
            use_per: When False, the agent's memory is replaced by a uniform
                ring buffer.
            use_safety_tier: When True, the sensor and task profile are passed
                to ``select_action``; otherwise only the state is.
            state_dim: Number of leading state entries fed to the agent.

        Returns:
            ``(mean_reward, mean_switches)`` averaged over the last 50
            episodes (or all of them when fewer were run).
        """
        from rl_env import ComplexityAwarePIMEnv
        from rl_agent import ComplexityAwareRLAgent, Transition

        buffer_size = 20000
        device = torch.device(self.device)

        env = ComplexityAwarePIMEnv(max_steps=self.max_steps)
        agent = ComplexityAwareRLAgent(
            state_dim=state_dim, device=self.device,
            buffer_size=buffer_size, batch_size=64)

        if not use_dueling:
            # Reuse the module-level FlatDQN rather than a private copy of the
            # same architecture, so both stay in sync.
            agent.policy_net = FlatDQN(state_dim).to(device)
            agent.target_net = FlatDQN(state_dim).to(device)
            agent.target_net.load_state_dict(agent.policy_net.state_dict())
            agent.optimizer = torch.optim.Adam(agent.policy_net.parameters(), lr=5e-4)

        if not use_per:
            agent.memory = self._UniformBuffer(
                buffer_size, transition_cls=Transition, device=device)

        rewards, switches = [], []
        for _episode in range(self.num_episodes):
            # Keep only the leading ``state_dim`` entries; this is a no-op
            # when the observation already has that length.
            state = env.reset()[:state_dim]
            total_r = 0
            # Defined up front so a zero-step episode still has a value.
            info = {"switches": 0}
            for _step in range(self.max_steps):
                if use_safety_tier:
                    action = agent.select_action(
                        state, sensor=env.sensor,
                        task_profile=env.current_profile)
                else:
                    action = agent.select_action(state, training=True)
                next_state, reward, done, info = env.step(action)
                next_state = next_state[:state_dim]
                agent.store_transition(state, action, reward, next_state, float(done))
                agent.train_step()
                total_r += reward
                state = next_state
                if done:
                    break
            rewards.append(total_r)
            switches.append(info["switches"])

        return self._tail_mean(rewards), self._tail_mean(switches)

    @staticmethod
    def _print_results(results: Dict[str, Dict]) -> None:
        """Print the ablation table, including the mean-switches column."""
        print("\n ABLATION RESULTS (last-50-episode averages):")
        print(f" {'Variant':<35} {'Mean Reward':>13} {'Mean Switches':>14}")
        print(" " + "-" * 62)
        baseline_r = results["Full system"]["mean_reward"]
        for name, m in results.items():
            drop = baseline_r - m["mean_reward"]
            # Only annotate drops larger than 0.1 relative to the full system.
            drop_str = f" (−{drop:.2f})" if drop > 0.1 else ""
            print(f" {name:<35} {m['mean_reward']:>13.2f} "
                  f"{m['mean_switches']:>14.1f}{drop_str}")

    def run(self) -> Dict[str, Dict]:
        """Train every variant, print the summary table and return the results.

        Returns:
            Mapping from variant name to
            ``{"mean_reward": float, "mean_switches": float}``.
        """
        print("\n--- Ablation Study ---")
        results = {}
        # (name, use_dueling, use_per, use_safety_tier, state_dim)
        variants = [
            ("Full system", True, True, True, 16),
            ("No dueling (flat DQN)", False, True, True, 16),
            ("No PER (uniform replay)", True, False, True, 16),
            ("No 3-tier hierarchy", True, True, False, 16),
            ("Physics-only state (8D)", True, True, True, 8),
        ]
        for name, dueling, per, safety, sdim in variants:
            print(f" Training variant: {name}...")
            r, sw = self._train_variant(name, dueling, per, safety, sdim)
            results[name] = {"mean_reward": r, "mean_switches": sw}
            print(f" → mean_reward={r:.2f}, mean_switches={sw:.1f}")
        self._print_results(results)
        return results