"""
controller.py
=============
Complexity-aware PIM controller — seamless task router.
"""

import warnings
from typing import Dict, List, Optional, Tuple

import numpy as np

try:
    import torch
    import torch.nn as nn
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

from profiler import TaskComplexityProfile, TaskComplexityProfiler
from physics import PhysicsSensorModel
from rl_env import ComplexityAwarePIMEnv
from rl_agent import ComplexityAwareRLAgent


class ComplexityAwarePIMController:
    """Route tasks to PIM, CPU or GPU using an RL agent and a physics sensor.

    For every task the controller builds a 16-element float32 state vector
    (8 physics features followed by 8 task-complexity features), asks the
    agent for an action, and records the decision together with the
    profiler's latency/energy estimates.
    """

    # Scalar encoding of the previously selected target (physics state).
    _TARGET_ENCODING = {"PIM": 0.0, "CPU": 0.5, "GPU": 1.0}
    # Scalar encoding of the profile's complexity class (complexity state).
    _CLASS_ENCODING = {"LIGHT": 0.0, "MEDIUM": 0.5, "HEAVY": 1.0}

    def __init__(self, device: str = "cpu", agent_path: Optional[str] = None):
        """Create the profiler, sensor and agent, and reset all logs.

        Args:
            device: Device string forwarded to ``ComplexityAwareRLAgent``.
            agent_path: Optional path handed to ``agent.load``; skipped when
                falsy (``None`` or empty string).
        """
        self.profiler = TaskComplexityProfiler()
        self.sensor = PhysicsSensorModel()
        self.agent = ComplexityAwareRLAgent(device=device)
        if agent_path:
            self.agent.load(agent_path)

        self.current_target = "PIM"
        self.access_log: List[Dict] = []
        self.decision_log: List[Dict] = []
        self.task_log: List[TaskComplexityProfile] = []
        self.counts = {"PIM": 0, "CPU": 0, "GPU": 0}
        self.total_energy_mj = 0.0
        self.total_latency_ms = 0.0

    # The torch annotations are quoted so that this module still imports when
    # torch is absent; unquoted they are evaluated at class-definition time
    # and defeat the ``try/except ImportError`` guard above.
    def route_model(self, model: "nn.Module", input_shape: Tuple,
                    timesteps: int = 1,
                    sample_input: "Optional[torch.Tensor]" = None) -> str:
        """Profile ``model`` and route the resulting task.

        Args:
            model: Model passed to ``TaskComplexityProfiler.profile``.
            input_shape: Input shape forwarded to the profiler.
            timesteps: Timestep count forwarded to the profiler.
            sample_input: Optional sample tensor forwarded to the profiler.

        Returns:
            The selected target name, as returned by :meth:`route_task`.

        Raises:
            ImportError: If PyTorch could not be imported.
        """
        if not HAS_TORCH:
            raise ImportError(
                "route_model requires PyTorch, which is not installed")
        profile = self.profiler.profile(
            model, input_shape, timesteps, sample_input)
        return self.route_task(profile)

    def route_task(self, profile: TaskComplexityProfile,
                   workload_override: Optional[float] = None) -> str:
        """Choose a target for ``profile`` and record the decision.

        Args:
            profile: Complexity profile of the task to route.
            workload_override: Workload value to use instead of the one
                derived from ``profile``; ``None`` means "derive it".

        Returns:
            The target name looked up in
            ``ComplexityAwarePIMEnv.ACTION_NAMES`` for the agent's action.
        """
        if workload_override is not None:
            workload = workload_override
        else:
            workload = self._estimate_workload(profile)
        spike_rate = 1.0 - profile.input_sparsity

        # The sensor is advanced before it is sampled, so the state the
        # agent sees already reflects this task running on the previous
        # target.
        self.sensor.update_temperature(workload, self.current_target)
        physics = self._get_physics_state(workload, spike_rate)
        complexity = self._get_complexity_state(profile)
        state = np.concatenate([physics, complexity])

        action = self.agent.select_action(
            state, sensor=self.sensor, task_profile=profile, training=False)
        target = ComplexityAwarePIMEnv.ACTION_NAMES[action]

        if target == "PIM":
            self.sensor.record_write(1)

        est_latency = self.profiler.estimate_latency(profile, target)
        est_energy = self.profiler.estimate_energy(profile, target)
        self.total_latency_ms += est_latency
        self.total_energy_mj += est_energy

        self.decision_log.append({
            "step": len(self.decision_log),
            "target": target,
            "task_class": profile.complexity_class,
            "flops": profile.total_flops,
            "memory": profile.total_memory_bytes,
            "ai": profile.arithmetic_intensity,
            "temperature": self.sensor.T_current,
            "temperature_case": self.sensor.T_case,
            "reliability": self.sensor.get_thermal_reliability(),
            "retention_time_s": self.sensor.get_retention_time(),
            "est_latency_ms": est_latency,
            "est_energy_mj": est_energy,
        })
        self.task_log.append(profile)

        self.current_target = target
        self.counts[target] += 1
        return target

    def access(self, addr, pim_op=False, workload_intensity=0.5,
               spike_rate=0.1):
        """Route a single access using a profile synthesised from intensity.

        Args:
            addr: Address recorded in ``access_log``; not otherwise used.
            pim_op: Flag recorded in ``access_log``; not otherwise used.
            workload_intensity: Scales the synthetic profile and is also
                passed to :meth:`route_task` as the workload override.
            spike_rate: Converted to ``input_sparsity = 1 - spike_rate``.

        Returns:
            The target name selected by :meth:`route_task`.
        """
        if workload_intensity < 0.3:
            complexity_class = "LIGHT"
        elif workload_intensity < 0.7:
            complexity_class = "MEDIUM"
        else:
            complexity_class = "HEAVY"

        default_profile = TaskComplexityProfile(
            total_flops=workload_intensity * 1e8,
            total_memory_bytes=workload_intensity * 1e6,
            input_sparsity=1.0 - spike_rate,
            arithmetic_intensity=workload_intensity * 50,
            is_memory_bound=(workload_intensity < 0.4),
            complexity_class=complexity_class,
        )
        target = self.route_task(
            default_profile, workload_override=workload_intensity)
        self.access_log.append({'addr': addr, 'pim': pim_op, 'target': target})
        return target

    def get_stats(self) -> Dict:
        """Return routing shares and accumulated energy/latency totals.

        The same keys are returned whether or not any task has been routed.
        Before the first task, shares and per-task averages are ``0.0``
        (avoiding a division by zero), and the legacy ``"PIM"``, ``"CPU"``
        and ``"GPU"`` zero entries are additionally present for callers
        written against the older empty-result shape.
        """
        total = sum(self.counts.values())
        if total == 0:
            return {
                # Legacy keys kept only for backward compatibility.
                "PIM": 0,
                "CPU": 0,
                "GPU": 0,
                "PIM_pct": 0.0,
                "CPU_pct": 0.0,
                "GPU_pct": 0.0,
                "total_accesses": 0,
                "total_energy_mJ": self.total_energy_mj,
                "total_latency_ms": self.total_latency_ms,
                "avg_energy_per_task_mJ": 0.0,
                "avg_latency_per_task_ms": 0.0,
            }
        return {
            "PIM_pct": self.counts["PIM"] / total,
            "CPU_pct": self.counts["CPU"] / total,
            "GPU_pct": self.counts["GPU"] / total,
            "total_accesses": total,
            "total_energy_mJ": self.total_energy_mj,
            "total_latency_ms": self.total_latency_ms,
            "avg_energy_per_task_mJ": self.total_energy_mj / total,
            "avg_latency_per_task_ms": self.total_latency_ms / total,
        }

    @staticmethod
    def _log_unit(value, decades):
        """Return ``log10(max(value, 1)) / decades`` clipped to [0, 1].

        Flooring at 1 keeps the logarithm non-negative and defined for
        zero or negative inputs.
        """
        return np.clip(np.log10(max(value, 1)) / decades, 0, 1)

    def _estimate_workload(self, prof):
        """Blend FLOPs, memory and timesteps into a workload in [0, 1].

        Weights are 0.4 (log FLOPs over 10 decades), 0.3 (log bytes over
        8 decades) and 0.3 (timesteps / 100).
        """
        flops_f = self._log_unit(prof.total_flops, 10)
        mem_f = self._log_unit(prof.total_memory_bytes, 8)
        ts_f = np.clip(prof.timesteps / 100, 0, 1)
        return np.clip(0.4 * flops_f + 0.3 * mem_f + 0.3 * ts_f, 0, 1)

    def _get_physics_state(self, workload, spike_rate):
        """Build the 8-element float32 physics half of the agent state.

        Order: temperature, threshold voltage, fault density, read margin,
        workload, spike rate, encoded current target, thermal reliability.
        Read margin and reliability are passed through as the sensor
        reports them; the other sensor readings are rescaled and clipped
        to [0, 1].
        """
        V_th = self.sensor.get_threshold_voltage()
        fd = self.sensor.get_fault_density()
        rm = self.sensor.get_read_margin()
        rel = self.sensor.get_thermal_reliability()
        return np.array([
            np.clip((self.sensor.T_current - 20) / 80, 0, 1),  # 20..100 -> 0..1
            np.clip((V_th - 0.3) / 0.6, 0, 1),                 # 0.3..0.9 -> 0..1
            np.clip(fd / 0.1, 0, 1),                           # 0..0.1 -> 0..1
            rm,
            np.clip(workload, 0, 1),
            np.clip(spike_rate, 0, 1),
            # Unknown targets fall back to the "PIM" encoding (0.0).
            self._TARGET_ENCODING.get(self.current_target, 0.0),
            rel,
        ], dtype=np.float32)

    def _get_complexity_state(self, prof):
        """Build the 8-element float32 complexity half of the agent state.

        Order: log FLOPs, log memory, arithmetic intensity, input sparsity,
        layer depth, timesteps, encoded complexity class, memory-bound flag.
        Input sparsity is passed through unclipped.
        """
        flops_n = self._log_unit(prof.total_flops, 10)
        mem_n = self._log_unit(prof.total_memory_bytes, 8)
        ai_n = np.clip(prof.arithmetic_intensity / 100, 0, 1)
        sp_n = prof.input_sparsity
        depth = np.clip(
            (prof.num_conv_layers + prof.num_linear_layers) / 10, 0, 1)
        ts_n = np.clip(prof.timesteps / 100, 0, 1)
        mb = 1.0 if prof.is_memory_bound else 0.0
        return np.array([
            flops_n, mem_n, ai_n, sp_n, depth, ts_n,
            # Unrecognised classes are treated as "MEDIUM" (0.5).
            self._CLASS_ENCODING.get(prof.complexity_class, 0.5),
            mb,
        ], dtype=np.float32)