# tempo-snn-v2 / src / controller.py
# KD099: Upload folder using huggingface_hub
# a157e36 (verified)
"""
controller.py
=============
Complexity-aware PIM controller — seamless task router.
"""
import warnings
from typing import Dict, List, Optional, Tuple
import numpy as np
try:
import torch
import torch.nn as nn
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
from profiler import TaskComplexityProfile, TaskComplexityProfiler
from physics import PhysicsSensorModel
from rl_env import ComplexityAwarePIMEnv
from rl_agent import ComplexityAwareRLAgent
class ComplexityAwarePIMController:
    """Route tasks to one of the PIM / CPU / GPU targets with an RL agent.

    For every task the controller builds a 16-element state (8 physics
    features read from the sensor model followed by 8 task-complexity
    features), asks the agent for an action, maps that action to a target
    name through ``ComplexityAwarePIMEnv.ACTION_NAMES`` and records the
    decision together with the profiler's latency/energy estimates.
    """

    # Divisors used to squash raw task figures into [0, 1]; shared by the
    # workload estimate and the complexity feature vector so both agree.
    _FLOPS_LOG_SCALE = 10
    _MEMORY_LOG_SCALE = 8
    _TIMESTEP_SCALE = 100

    # Scalar encodings fed to the agent for categorical values.
    _TARGET_ENCODING = {"PIM": 0.0, "CPU": 0.5, "GPU": 1.0}
    _CLASS_ENCODING = {"LIGHT": 0.0, "MEDIUM": 0.5, "HEAVY": 1.0}

    def __init__(self, device: str = "cpu", agent_path: Optional[str] = None):
        """Create the profiler, sensor model and agent.

        Args:
            device: Device string forwarded to ``ComplexityAwareRLAgent``.
            agent_path: If given (and non-empty), passed to ``agent.load``.
        """
        self.profiler = TaskComplexityProfiler()
        self.sensor = PhysicsSensorModel()
        self.agent = ComplexityAwareRLAgent(device=device)
        if agent_path:
            self.agent.load(agent_path)

        # Target chosen by the most recent routing decision.
        self.current_target = "PIM"

        self.access_log: List[Dict] = []
        self.decision_log: List[Dict] = []
        self.task_log: List[TaskComplexityProfile] = []

        self.counts = {"PIM": 0, "CPU": 0, "GPU": 0}
        self.total_energy_mj = 0.0
        self.total_latency_ms = 0.0

    # String annotations keep this class definable when the optional torch
    # import at module level failed (``nn`` / ``torch`` would be undefined).
    def route_model(self, model: "nn.Module", input_shape: Tuple,
                    timesteps: int = 1,
                    sample_input: "Optional[torch.Tensor]" = None) -> str:
        """Profile ``model`` with the profiler and route the resulting profile.

        Args:
            model: Model handed to ``self.profiler.profile``.
            input_shape: Input shape handed to the profiler.
            timesteps: Timestep count handed to the profiler.
            sample_input: Optional sample input handed to the profiler.

        Returns:
            The target name selected by :meth:`route_task`.
        """
        profile = self.profiler.profile(model, input_shape, timesteps,
                                        sample_input)
        return self.route_task(profile)

    def route_task(self, profile: "TaskComplexityProfile",
                   workload_override: Optional[float] = None) -> str:
        """Pick a target for ``profile`` and record the decision.

        Args:
            profile: Complexity profile of the task to route.
            workload_override: Workload value to use instead of the one
                derived from ``profile``; ``None`` means derive it.

        Returns:
            The selected target name.
        """
        if workload_override is not None:
            workload = workload_override
        else:
            workload = self._estimate_workload(profile)
        spike_rate = 1.0 - profile.input_sparsity

        # The temperature update uses the target of the *previous* decision,
        # since ``current_target`` is only reassigned further down.
        self.sensor.update_temperature(workload, self.current_target)

        state = np.concatenate([
            self._get_physics_state(workload, spike_rate),
            self._get_complexity_state(profile),
        ])
        action = self.agent.select_action(
            state, sensor=self.sensor,
            task_profile=profile, training=False)
        target = ComplexityAwarePIMEnv.ACTION_NAMES[action]

        if target == "PIM":
            self.sensor.record_write(1)

        est_latency = self.profiler.estimate_latency(profile, target)
        est_energy = self.profiler.estimate_energy(profile, target)
        self.total_latency_ms += est_latency
        self.total_energy_mj += est_energy

        self.decision_log.append({
            "step": len(self.decision_log),
            "target": target,
            "task_class": profile.complexity_class,
            "flops": profile.total_flops,
            "memory": profile.total_memory_bytes,
            "ai": profile.arithmetic_intensity,
            "temperature": self.sensor.T_current,
            "temperature_case": self.sensor.T_case,
            "reliability": self.sensor.get_thermal_reliability(),
            "retention_time_s": self.sensor.get_retention_time(),
            "est_latency_ms": est_latency,
            "est_energy_mj": est_energy,
        })
        self.task_log.append(profile)
        self.current_target = target
        self.counts[target] += 1
        return target

    def access(self, addr, pim_op=False, workload_intensity=0.5,
               spike_rate=0.1):
        """Route a single access described only by intensity and spike rate.

        A synthetic ``TaskComplexityProfile`` is derived from
        ``workload_intensity`` and ``spike_rate`` and routed with
        ``workload_intensity`` as the workload override.

        Args:
            addr: Address recorded in ``access_log`` (not otherwise used).
            pim_op: Flag recorded in ``access_log`` (not otherwise used).
            workload_intensity: Scales the synthetic FLOPs, memory and
                arithmetic intensity, and selects the complexity class.
            spike_rate: Synthetic input sparsity is ``1.0 - spike_rate``.

        Returns:
            The selected target name.
        """
        if workload_intensity < 0.3:
            complexity_class = "LIGHT"
        elif workload_intensity < 0.7:
            complexity_class = "MEDIUM"
        else:
            complexity_class = "HEAVY"

        default_profile = TaskComplexityProfile(
            total_flops=workload_intensity * 1e8,
            total_memory_bytes=workload_intensity * 1e6,
            input_sparsity=1.0 - spike_rate,
            arithmetic_intensity=workload_intensity * 50,
            is_memory_bound=(workload_intensity < 0.4),
            complexity_class=complexity_class,
        )
        target = self.route_task(default_profile,
                                 workload_override=workload_intensity)
        self.access_log.append({'addr': addr, 'pim': pim_op, 'target': target})
        return target

    def get_stats(self) -> Dict:
        """Summarise routing counts and accumulated energy/latency.

        The same keys are returned whether or not anything has been routed,
        so callers can index the result unconditionally. Before the first
        routed task all shares and averages are ``0.0``; in that case the
        legacy ``"PIM"`` / ``"CPU"`` / ``"GPU"`` zero entries are included
        as well, for callers written against the earlier empty-state result.

        Returns:
            Dict with ``PIM_pct``, ``CPU_pct``, ``GPU_pct`` (fractions of
            all routed tasks), ``total_accesses``, ``total_energy_mJ``,
            ``total_latency_ms``, ``avg_energy_per_task_mJ`` and
            ``avg_latency_per_task_ms``.
        """
        total = sum(self.counts.values())
        if total == 0:
            return {
                "PIM_pct": 0.0,
                "CPU_pct": 0.0,
                "GPU_pct": 0.0,
                "total_accesses": 0,
                "total_energy_mJ": self.total_energy_mj,
                "total_latency_ms": self.total_latency_ms,
                "avg_energy_per_task_mJ": 0.0,
                "avg_latency_per_task_ms": 0.0,
                # Legacy empty-state keys, kept for backward compatibility.
                "PIM": 0, "CPU": 0, "GPU": 0,
            }
        return {
            "PIM_pct": self.counts["PIM"] / total,
            "CPU_pct": self.counts["CPU"] / total,
            "GPU_pct": self.counts["GPU"] / total,
            "total_accesses": total,
            "total_energy_mJ": self.total_energy_mj,
            "total_latency_ms": self.total_latency_ms,
            "avg_energy_per_task_mJ": self.total_energy_mj / total,
            "avg_latency_per_task_ms": self.total_latency_ms / total,
        }

    @staticmethod
    def _log_unit(value, scale):
        """Return ``log10(max(value, 1)) / scale`` clipped to [0, 1]."""
        # max(..., 1) keeps log10 defined for zero-valued inputs.
        return np.clip(np.log10(max(value, 1)) / scale, 0, 1)

    def _estimate_workload(self, prof):
        """Blend FLOPs, memory and timesteps of ``prof`` into a [0, 1] value.

        Weights are 0.4 (log-scaled FLOPs), 0.3 (log-scaled memory bytes)
        and 0.3 (timesteps / 100).
        """
        flops_f = self._log_unit(prof.total_flops, self._FLOPS_LOG_SCALE)
        mem_f = self._log_unit(prof.total_memory_bytes, self._MEMORY_LOG_SCALE)
        ts_f = np.clip(prof.timesteps / self._TIMESTEP_SCALE, 0, 1)
        return np.clip(0.4*flops_f + 0.3*mem_f + 0.3*ts_f, 0, 1)

    def _get_physics_state(self, workload, spike_rate):
        """Build the 8-element float32 physics part of the agent state.

        Order: scaled temperature, scaled threshold voltage, scaled fault
        density, read margin (as returned by the sensor), workload, spike
        rate, encoded current target, thermal reliability (as returned by
        the sensor).
        """
        sensor = self.sensor
        v_th = sensor.get_threshold_voltage()
        fault_density = sensor.get_fault_density()
        read_margin = sensor.get_read_margin()
        reliability = sensor.get_thermal_reliability()
        return np.array([
            np.clip((sensor.T_current - 20)/80, 0, 1),
            np.clip((v_th - 0.3)/0.6, 0, 1),
            np.clip(fault_density/0.1, 0, 1),
            read_margin,
            np.clip(workload, 0, 1),
            np.clip(spike_rate, 0, 1),
            # An unrecognised target falls back to the PIM encoding.
            self._TARGET_ENCODING.get(self.current_target, 0.0),
            reliability,
        ], dtype=np.float32)

    def _get_complexity_state(self, prof):
        """Build the 8-element float32 complexity part of the agent state.

        Order: log-scaled FLOPs, log-scaled memory bytes, arithmetic
        intensity / 100, input sparsity, layer count / 10, timesteps / 100,
        encoded complexity class, memory-bound flag. Every numeric feature
        is clipped to [0, 1].
        """
        flops_n = self._log_unit(prof.total_flops, self._FLOPS_LOG_SCALE)
        mem_n = self._log_unit(prof.total_memory_bytes, self._MEMORY_LOG_SCALE)
        ai_n = np.clip(prof.arithmetic_intensity / 100, 0, 1)
        # Clipped like its counterpart spike_rate in the physics vector.
        sp_n = np.clip(prof.input_sparsity, 0, 1)
        depth = np.clip(
            (prof.num_conv_layers + prof.num_linear_layers)/10, 0, 1)
        ts_n = np.clip(prof.timesteps / self._TIMESTEP_SCALE, 0, 1)
        # An unrecognised class falls back to the MEDIUM encoding.
        cls_n = self._CLASS_ENCODING.get(prof.complexity_class, 0.5)
        mb = 1.0 if prof.is_memory_bound else 0.0
        return np.array([flops_n, mem_n, ai_n, sp_n, depth, ts_n, cls_n, mb],
                        dtype=np.float32)