| """ |
| controller.py |
| ============= |
| Complexity-aware PIM controller — seamless task router. |
| """ |
|
|
| import warnings |
| from typing import Dict, List, Optional, Tuple |
| import numpy as np |
|
|
| try: |
| import torch |
| import torch.nn as nn |
| HAS_TORCH = True |
| except ImportError: |
| HAS_TORCH = False |
|
|
| from profiler import TaskComplexityProfile, TaskComplexityProfiler |
| from physics import PhysicsSensorModel |
| from rl_env import ComplexityAwarePIMEnv |
| from rl_agent import ComplexityAwareRLAgent |
|
|
|
|
class ComplexityAwarePIMController:
    """Route tasks to PIM, CPU or GPU with a complexity-aware RL agent.

    For every task the controller builds a 16-element state vector:
    8 physics features read from a ``PhysicsSensorModel`` (see
    ``_get_physics_state``), followed by 8 complexity features derived from
    a ``TaskComplexityProfile`` (see ``_get_complexity_state``).

    The vector is passed to ``ComplexityAwareRLAgent.select_action`` with
    ``training=False``. The returned action index is mapped to a target
    name through ``ComplexityAwarePIMEnv.ACTION_NAMES``.

    The controller keeps these running totals and logs:

    * estimated latency/energy of each chosen target, accumulated in
      ``total_latency_ms`` / ``total_energy_mj``;
    * per-target decision counts, in ``counts``;
    * one record per decision, in ``decision_log``.

    Annotations that mention PyTorch types are written as strings. The
    module imports torch inside ``try``/``except`` and must still be
    importable when torch is missing. Unquoted ``nn``/``torch`` annotations
    would raise ``NameError`` while the class body is being executed.
    """

    # Ordinal scalar encodings used as single features of the RL state.
    _TARGET_ENCODING = {"PIM": 0.0, "CPU": 0.5, "GPU": 1.0}
    _CLASS_ENCODING = {"LIGHT": 0.0, "MEDIUM": 0.5, "HEAVY": 1.0}

    def __init__(self, device: str = "cpu", agent_path: Optional[str] = None):
        """Create the profiler, the sensor model and the RL agent.

        Args:
            device: Device string forwarded to ``ComplexityAwareRLAgent``.
            agent_path: Optional path to saved agent weights. When truthy,
                it is loaded with ``agent.load``.
        """
        self.profiler = TaskComplexityProfiler()
        self.sensor = PhysicsSensorModel()
        self.agent = ComplexityAwareRLAgent(device=device)
        if agent_path:
            self.agent.load(agent_path)

        # Target of the most recent decision. route_task uses it for the
        # sensor's temperature update before making the next decision.
        self.current_target = "PIM"
        self.access_log: List[Dict] = []
        self.decision_log: List[Dict] = []
        self.task_log: List[TaskComplexityProfile] = []
        self.counts = {"PIM": 0, "CPU": 0, "GPU": 0}
        self.total_energy_mj = 0.0
        self.total_latency_ms = 0.0

    def route_model(self, model: "nn.Module", input_shape: Tuple,
                    timesteps: int = 1,
                    sample_input: "Optional[torch.Tensor]" = None) -> str:
        """Profile ``model`` and route the resulting task.

        Args:
            model: Model to profile with ``TaskComplexityProfiler.profile``.
            input_shape: Input shape forwarded to the profiler.
            timesteps: Timestep count forwarded to the profiler.
            sample_input: Optional concrete input forwarded to the profiler.

        Returns:
            The chosen target name (see ``route_task``).
        """
        profile = self.profiler.profile(model, input_shape, timesteps,
                                        sample_input)
        return self.route_task(profile)

    def route_task(self, profile: "TaskComplexityProfile",
                   workload_override: Optional[float] = None) -> str:
        """Choose an execution target for an already-profiled task.

        Args:
            profile: Complexity profile of the task.
            workload_override: Workload intensity to use instead of the
                value ``_estimate_workload`` derives from ``profile``.

        Returns:
            The chosen target name, taken from
            ``ComplexityAwarePIMEnv.ACTION_NAMES``. This class keeps counts
            only for "PIM", "CPU" and "GPU".
        """
        if workload_override is not None:
            workload = workload_override
        else:
            workload = self._estimate_workload(profile)
        spike_rate = 1.0 - profile.input_sparsity

        # The thermal model is advanced with the device chosen by the
        # *previous* decision; current_target is reassigned only at the end.
        self.sensor.update_temperature(workload, self.current_target)

        state = np.concatenate([
            self._get_physics_state(workload, spike_rate),
            self._get_complexity_state(profile),
        ])
        action = self.agent.select_action(
            state, sensor=self.sensor, task_profile=profile, training=False)
        target = ComplexityAwarePIMEnv.ACTION_NAMES[action]

        if target == "PIM":
            # Each PIM dispatch is reported to the sensor model as one write.
            self.sensor.record_write(1)

        est_latency = self.profiler.estimate_latency(profile, target)
        est_energy = self.profiler.estimate_energy(profile, target)
        self.total_latency_ms += est_latency
        self.total_energy_mj += est_energy

        # Sensor readings are taken after this decision's temperature update
        # and PIM write have been applied.
        self.decision_log.append({
            "step": len(self.decision_log),
            "target": target,
            "task_class": profile.complexity_class,
            "flops": profile.total_flops,
            "memory": profile.total_memory_bytes,
            "ai": profile.arithmetic_intensity,
            "temperature": self.sensor.T_current,
            "temperature_case": self.sensor.T_case,
            "reliability": self.sensor.get_thermal_reliability(),
            "retention_time_s": self.sensor.get_retention_time(),
            "est_latency_ms": est_latency,
            "est_energy_mj": est_energy,
        })
        self.task_log.append(profile)
        self.current_target = target
        self.counts[target] += 1

        return target

    def access(self, addr, pim_op: bool = False,
               workload_intensity: float = 0.5,
               spike_rate: float = 0.1) -> str:
        """Route a single access described only by scalar knobs.

        A synthetic ``TaskComplexityProfile`` is built from
        ``workload_intensity`` and ``spike_rate``. It is routed through
        ``route_task``, with ``workload_intensity`` as the workload override.

        Args:
            addr: Access address. It is only recorded in ``access_log``.
            pim_op: Only recorded in ``access_log``. It does not influence
                the routing decision.
            workload_intensity: Scalar intensity. It scales FLOPs (x1e8),
                memory bytes (x1e6) and arithmetic intensity (x50). It also
                sets ``is_memory_bound`` (< 0.4) and the complexity class
                (< 0.3 LIGHT, < 0.7 MEDIUM, else HEAVY).
            spike_rate: Converted to ``input_sparsity = 1 - spike_rate``.

        Returns:
            The chosen target name.
        """
        if workload_intensity < 0.3:
            complexity_class = "LIGHT"
        elif workload_intensity < 0.7:
            complexity_class = "MEDIUM"
        else:
            complexity_class = "HEAVY"

        # Fields not passed here (e.g. timesteps, layer counts) fall back to
        # TaskComplexityProfile's own defaults.
        default_profile = TaskComplexityProfile(
            total_flops=workload_intensity * 1e8,
            total_memory_bytes=workload_intensity * 1e6,
            input_sparsity=1.0 - spike_rate,
            arithmetic_intensity=workload_intensity * 50,
            is_memory_bound=(workload_intensity < 0.4),
            complexity_class=complexity_class,
        )
        target = self.route_task(default_profile,
                                 workload_override=workload_intensity)
        self.access_log.append({'addr': addr, 'pim': pim_op, 'target': target})
        return target

    def get_stats(self) -> Dict:
        """Summarise routing decisions and accumulated cost estimates.

        The same keys are returned whether or not any task has been routed.
        Ratios and averages are ``0.0`` when nothing has been routed yet,
        instead of raising ``ZeroDivisionError``.

        Returns:
            A dict with the following keys:

            * ``PIM_pct``, ``CPU_pct``, ``GPU_pct``: fractions in [0, 1];
            * ``total_accesses``;
            * ``total_energy_mJ``, ``total_latency_ms``;
            * ``avg_energy_per_task_mJ``, ``avg_latency_per_task_ms``.
        """
        total = sum(self.counts.values())

        def ratio(value: float) -> float:
            return value / total if total else 0.0

        return {
            "PIM_pct": ratio(self.counts["PIM"]),
            "CPU_pct": ratio(self.counts["CPU"]),
            "GPU_pct": ratio(self.counts["GPU"]),
            "total_accesses": total,
            "total_energy_mJ": self.total_energy_mj,
            "total_latency_ms": self.total_latency_ms,
            "avg_energy_per_task_mJ": ratio(self.total_energy_mj),
            "avg_latency_per_task_ms": ratio(self.total_latency_ms),
        }

    @staticmethod
    def _log_norm(value: float, decades: float):
        """Return ``log10(max(value, 1)) / decades`` clipped to [0, 1]."""
        return np.clip(np.log10(max(value, 1)) / decades, 0, 1)

    def _estimate_workload(self, prof) -> float:
        """Estimate a workload intensity in [0, 1] from a profile.

        The estimate is a weighted mix of three normalised features:

        * FLOPs, log-normalised over 10 decades (weight 0.4);
        * memory bytes, log-normalised over 8 decades (weight 0.3);
        * timesteps divided by 100 (weight 0.3).
        """
        flops_f = self._log_norm(prof.total_flops, 10)
        mem_f = self._log_norm(prof.total_memory_bytes, 8)
        ts_f = np.clip(prof.timesteps / 100, 0, 1)
        return np.clip(0.4 * flops_f + 0.3 * mem_f + 0.3 * ts_f, 0, 1)

    def _get_physics_state(self, workload, spike_rate) -> np.ndarray:
        """Build the 8-element float32 physics part of the RL state.

        The features, in order, are:

        1. temperature, (T - 20) / 80;
        2. threshold voltage, (V_th - 0.3) / 0.6;
        3. fault density divided by 0.1;
        4. read margin;
        5. workload;
        6. spike rate;
        7. encoding of the current target;
        8. thermal reliability.

        Read margin and reliability are used as returned by the sensor,
        without clipping. Every other feature is clipped to [0, 1].
        """
        V_th = self.sensor.get_threshold_voltage()
        fd = self.sensor.get_fault_density()
        rm = self.sensor.get_read_margin()
        rel = self.sensor.get_thermal_reliability()
        return np.array([
            np.clip((self.sensor.T_current - 20) / 80, 0, 1),
            np.clip((V_th - 0.3) / 0.6, 0, 1),
            np.clip(fd / 0.1, 0, 1),
            rm,
            np.clip(workload, 0, 1),
            np.clip(spike_rate, 0, 1),
            self._TARGET_ENCODING.get(self.current_target, 0.0),
            rel,
        ], dtype=np.float32)

    def _get_complexity_state(self, prof) -> np.ndarray:
        """Build the 8-element float32 complexity part of the RL state.

        The features, in order, are:

        1. FLOPs, log-normalised over 10 decades;
        2. memory bytes, log-normalised over 8 decades;
        3. arithmetic intensity divided by 100;
        4. input sparsity, unclipped;
        5. conv + linear layer count divided by 10;
        6. timesteps divided by 100;
        7. complexity-class encoding (unknown classes map to 0.5);
        8. memory-bound flag, 1.0 or 0.0.
        """
        depth = prof.num_conv_layers + prof.num_linear_layers
        return np.array([
            self._log_norm(prof.total_flops, 10),
            self._log_norm(prof.total_memory_bytes, 8),
            np.clip(prof.arithmetic_intensity / 100, 0, 1),
            prof.input_sparsity,
            np.clip(depth / 10, 0, 1),
            np.clip(prof.timesteps / 100, 0, 1),
            self._CLASS_ENCODING.get(prof.complexity_class, 0.5),
            1.0 if prof.is_memory_bound else 0.0,
        ], dtype=np.float32)
|
|