"""
controller.py
=============
Complexity-aware PIM controller: routes each task to PIM, CPU, or GPU using an RL agent.
"""
import warnings
from typing import Dict, List, Optional, Tuple
import numpy as np
try:
import torch
import torch.nn as nn
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
from profiler import TaskComplexityProfile, TaskComplexityProfiler
from physics import PhysicsSensorModel
from rl_env import ComplexityAwarePIMEnv
from rl_agent import ComplexityAwareRLAgent
class ComplexityAwarePIMController:
    """Route tasks to PIM, CPU or GPU with a complexity-aware RL agent.

    Every routing decision builds a 16-element state vector (8 physics
    features from the sensor model followed by 8 complexity features from
    the task profile). It asks the agent for an action in inference mode
    (``training=False``) and records the decision together with the
    profiler's latency/energy estimates.
    """

    # Numeric encodings used inside the state vectors.
    _TARGET_ENCODING = {"PIM": 0.0, "CPU": 0.5, "GPU": 1.0}
    _CLASS_ENCODING = {"LIGHT": 0.0, "MEDIUM": 0.5, "HEAVY": 1.0}

    def __init__(self, device: str = "cpu", agent_path: Optional[str] = None):
        """Create the profiler, sensor model and RL agent.

        Args:
            device: Device string forwarded to ``ComplexityAwareRLAgent``.
            agent_path: Optional checkpoint path; when truthy it is passed
                to ``agent.load``.
        """
        self.profiler = TaskComplexityProfiler()
        self.sensor = PhysicsSensorModel()
        self.agent = ComplexityAwareRLAgent(device=device)
        if agent_path:
            self.agent.load(agent_path)
        # Most recent routing target. It is fed back into both the sensor
        # update and the physics state of the next decision.
        self.current_target = "PIM"
        self.access_log: List[Dict] = []
        self.decision_log: List[Dict] = []
        self.task_log: List["TaskComplexityProfile"] = []
        self.counts = {"PIM": 0, "CPU": 0, "GPU": 0}
        self.total_energy_mj = 0.0
        self.total_latency_ms = 0.0

    # String annotations: torch is an optional import (HAS_TORCH), so
    # ``nn`` / ``torch`` may be undefined when this class is created.
    def route_model(self, model: "nn.Module", input_shape: Tuple,
                    timesteps: int = 1,
                    sample_input: Optional["torch.Tensor"] = None) -> str:
        """Profile ``model`` with the task profiler, then route the profile.

        Returns:
            The chosen target name (see ``route_task``).
        """
        profile = self.profiler.profile(model, input_shape, timesteps,
                                        sample_input)
        return self.route_task(profile)

    def route_task(self, profile: "TaskComplexityProfile",
                   workload_override: Optional[float] = None) -> str:
        """Choose an execution target for ``profile`` and log the decision.

        Args:
            profile: Task profile (as produced by ``TaskComplexityProfiler``).
            workload_override: If given, used instead of the workload score
                derived from the profile by ``_estimate_workload``.

        Returns:
            One of ``ComplexityAwarePIMEnv.ACTION_NAMES``.
        """
        if workload_override is not None:
            workload = workload_override
        else:
            workload = self._estimate_workload(profile)
        spike_rate = 1.0 - profile.input_sparsity

        # The sensor is advanced with the *previous* target, before the
        # new decision is taken.
        self.sensor.update_temperature(workload, self.current_target)
        physics = self._get_physics_state(workload, spike_rate)
        complexity = self._get_complexity_state(profile)
        state = np.concatenate([physics, complexity])

        action = self.agent.select_action(
            state, sensor=self.sensor,
            task_profile=profile, training=False)
        target = ComplexityAwarePIMEnv.ACTION_NAMES[action]
        if target == "PIM":
            self.sensor.record_write(1)

        est_latency = self.profiler.estimate_latency(profile, target)
        est_energy = self.profiler.estimate_energy(profile, target)
        self.total_latency_ms += est_latency
        self.total_energy_mj += est_energy

        self.decision_log.append({
            "step": len(self.decision_log),
            "target": target,
            "task_class": profile.complexity_class,
            "flops": profile.total_flops,
            "memory": profile.total_memory_bytes,
            "ai": profile.arithmetic_intensity,
            "temperature": self.sensor.T_current,
            "temperature_case": self.sensor.T_case,
            "reliability": self.sensor.get_thermal_reliability(),
            "retention_time_s": self.sensor.get_retention_time(),
            "est_latency_ms": est_latency,
            "est_energy_mj": est_energy,
        })
        self.task_log.append(profile)
        self.current_target = target
        self.counts[target] += 1
        return target

    def access(self, addr, pim_op=False, workload_intensity=0.5,
               spike_rate=0.1):
        """Route a synthetic access described only by a workload intensity.

        A placeholder ``TaskComplexityProfile`` is derived from
        ``workload_intensity`` and ``spike_rate``. It is routed with
        ``workload_intensity`` as the workload override, and the access is
        appended to ``access_log``.

        Returns:
            The chosen target name.
        """
        default_profile = TaskComplexityProfile(
            total_flops=workload_intensity * 1e8,
            total_memory_bytes=workload_intensity * 1e6,
            input_sparsity=1.0 - spike_rate,
            arithmetic_intensity=workload_intensity * 50,
            is_memory_bound=(workload_intensity < 0.4),
            complexity_class="LIGHT" if workload_intensity < 0.3 else
            ("MEDIUM" if workload_intensity < 0.7 else "HEAVY"),
        )
        target = self.route_task(default_profile,
                                 workload_override=workload_intensity)
        self.access_log.append({'addr': addr, 'pim': pim_op,
                                'target': target})
        return target

    def get_stats(self) -> Dict:
        """Summarise routing decisions so far.

        The same keys are returned whether or not any decision has been
        made, so callers can index the result unconditionally. Before the
        first decision every value is zero.
        """
        total = sum(self.counts.values())
        if total == 0:
            # Previously this returned {"PIM", "CPU", "GPU"}, a schema
            # different from the non-empty case (KeyError on "PIM_pct").
            return {
                "PIM_pct": 0.0,
                "CPU_pct": 0.0,
                "GPU_pct": 0.0,
                "total_accesses": 0,
                "total_energy_mJ": self.total_energy_mj,
                "total_latency_ms": self.total_latency_ms,
                "avg_energy_per_task_mJ": 0.0,
                "avg_latency_per_task_ms": 0.0,
            }
        return {
            "PIM_pct": self.counts["PIM"] / total,
            "CPU_pct": self.counts["CPU"] / total,
            "GPU_pct": self.counts["GPU"] / total,
            "total_accesses": total,
            "total_energy_mJ": self.total_energy_mj,
            "total_latency_ms": self.total_latency_ms,
            "avg_energy_per_task_mJ": self.total_energy_mj / total,
            "avg_latency_per_task_ms": self.total_latency_ms / total,
        }

    @staticmethod
    def _log_norm(value, decades):
        """Map ``log10(max(value, 1)) / decades`` into [0, 1]."""
        return np.clip(np.log10(max(value, 1)) / decades, 0, 1)

    def _estimate_workload(self, prof):
        """Heuristic workload score in [0, 1] from flops, memory, timesteps."""
        flops_f = self._log_norm(prof.total_flops, 10)
        mem_f = self._log_norm(prof.total_memory_bytes, 8)
        ts_f = np.clip(prof.timesteps / 100, 0, 1)
        return np.clip(0.4 * flops_f + 0.3 * mem_f + 0.3 * ts_f, 0, 1)

    def _get_physics_state(self, workload, spike_rate):
        """Build the 8-element float32 physics part of the agent state.

        Order: temperature, threshold voltage, fault density, read margin,
        workload, spike rate, current-target encoding, thermal reliability.
        The read margin and reliability are passed through unscaled.
        """
        v_th = self.sensor.get_threshold_voltage()
        fault_density = self.sensor.get_fault_density()
        read_margin = self.sensor.get_read_margin()
        reliability = self.sensor.get_thermal_reliability()
        return np.array([
            np.clip((self.sensor.T_current - 20) / 80, 0, 1),
            np.clip((v_th - 0.3) / 0.6, 0, 1),
            np.clip(fault_density / 0.1, 0, 1),
            read_margin,
            np.clip(workload, 0, 1),
            np.clip(spike_rate, 0, 1),
            self._TARGET_ENCODING.get(self.current_target, 0.0),
            reliability,
        ], dtype=np.float32)

    def _get_complexity_state(self, prof):
        """Build the 8-element float32 complexity part of the agent state.

        Order: flops, memory, arithmetic intensity, input sparsity, layer
        depth, timesteps, complexity-class encoding, memory-bound flag.
        """
        flops_n = self._log_norm(prof.total_flops, 10)
        mem_n = self._log_norm(prof.total_memory_bytes, 8)
        ai_n = np.clip(prof.arithmetic_intensity / 100, 0, 1)
        # Clipped like every other feature. access() derives sparsity from
        # an arbitrary spike_rate, so it can fall outside [0, 1].
        sp_n = np.clip(prof.input_sparsity, 0, 1)
        depth = np.clip(
            (prof.num_conv_layers + prof.num_linear_layers) / 10, 0, 1)
        ts_n = np.clip(prof.timesteps / 100, 0, 1)
        cls_n = self._CLASS_ENCODING.get(prof.complexity_class, 0.5)
        mb = 1.0 if prof.is_memory_bound else 0.0
        return np.array([flops_n, mem_n, ai_n, sp_n, depth, ts_n, cls_n, mb],
                        dtype=np.float32)