File size: 6,187 Bytes
a157e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
controller.py
=============
Complexity-aware PIM controller — seamless task router.
"""

import warnings
from typing import Dict, List, Optional, Tuple
import numpy as np

try:
    import torch
    import torch.nn as nn
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

from profiler import TaskComplexityProfile, TaskComplexityProfiler
from physics import PhysicsSensorModel
from rl_env import ComplexityAwarePIMEnv
from rl_agent import ComplexityAwareRLAgent


class ComplexityAwarePIMController:
    """Route tasks to a PIM, CPU or GPU target using a complexity-aware RL agent.

    Every routing decision concatenates eight physics features (read from a
    ``PhysicsSensorModel``) with eight task-complexity features (derived from a
    ``TaskComplexityProfile``) into a 16-element ``float32`` state vector. It
    then asks the agent for an action (inference mode, ``training=False``) and
    records the chosen target. It also accumulates the profiler's estimated
    latency/energy for that target.

    Annotations that name project or optional third-party types are written as
    string forward references. This keeps the class importable when the
    module-level ``torch`` import failed (``HAS_TORCH`` is False).
    """

    # Ordinal encoding of the target that was active before a decision.
    _TARGET_ENCODING = {"PIM": 0.0, "CPU": 0.5, "GPU": 1.0}
    # Ordinal encoding of the profiler's complexity class; unknown -> 0.5.
    _CLASS_ENCODING = {"LIGHT": 0.0, "MEDIUM": 0.5, "HEAVY": 1.0}

    def __init__(self, device: str = "cpu", agent_path: Optional[str] = None):
        """Create the profiler, sensor model and RL agent.

        Args:
            device: Device string forwarded to ``ComplexityAwareRLAgent``.
            agent_path: When truthy, agent weights are loaded with
                ``agent.load(agent_path)``.
        """
        self.profiler = TaskComplexityProfiler()
        self.sensor = PhysicsSensorModel()
        self.agent = ComplexityAwareRLAgent(device=device)
        if agent_path:
            self.agent.load(agent_path)

        # Target chosen by the most recent decision; the controller starts on PIM.
        self.current_target = "PIM"
        self.access_log: List[Dict] = []
        self.decision_log: List[Dict] = []
        self.task_log: List["TaskComplexityProfile"] = []
        self.counts = {"PIM": 0, "CPU": 0, "GPU": 0}
        # Running totals of the profiler's *estimated* cost of each decision.
        self.total_energy_mj = 0.0
        self.total_latency_ms = 0.0

    def route_model(self, model: "nn.Module", input_shape: Tuple,
                    timesteps: int = 1,
                    sample_input: "Optional[torch.Tensor]" = None) -> str:
        """Profile ``model`` and route the resulting task.

        The torch types are quoted on purpose. An eager ``nn.Module`` annotation
        would raise ``NameError`` at class-definition time whenever torch is
        not installed.

        Args:
            model: Model handed to ``TaskComplexityProfiler.profile``.
            input_shape: Input shape forwarded to the profiler.
            timesteps: Timestep count forwarded to the profiler.
            sample_input: Optional concrete input forwarded to the profiler.

        Returns:
            The chosen target name (see ``route_task``).
        """
        profile = self.profiler.profile(model, input_shape, timesteps, sample_input)
        return self.route_task(profile)

    def route_task(self, profile: "TaskComplexityProfile",
                   workload_override: Optional[float] = None) -> str:
        """Choose an execution target for ``profile`` and record the decision.

        Args:
            profile: Complexity profile of the task to route.
            workload_override: If not None, used as the workload instead of
                ``_estimate_workload(profile)``.

        Returns:
            The target name, taken from ``ComplexityAwarePIMEnv.ACTION_NAMES``.

        Side effects:
            Advances the sensor's temperature and records one write on the
            sensor when PIM is chosen. Accumulates estimated latency/energy,
            appends to ``decision_log`` and ``task_log``, and updates
            ``current_target`` and ``counts``.
        """
        if workload_override is not None:
            workload = workload_override
        else:
            workload = self._estimate_workload(profile)
        spike_rate = 1.0 - profile.input_sparsity

        # The sensor is advanced with the target active *before* this decision
        # (current_target is only reassigned at the end of this method).
        self.sensor.update_temperature(workload, self.current_target)

        physics = self._get_physics_state(workload, spike_rate)
        complexity = self._get_complexity_state(profile)
        state = np.concatenate([physics, complexity])

        action = self.agent.select_action(
            state, sensor=self.sensor,
            task_profile=profile, training=False)
        target = ComplexityAwarePIMEnv.ACTION_NAMES[action]

        if target == "PIM":
            self.sensor.record_write(1)

        est_latency = self.profiler.estimate_latency(profile, target)
        est_energy = self.profiler.estimate_energy(profile, target)
        self.total_latency_ms += est_latency
        self.total_energy_mj += est_energy

        self.decision_log.append({
            "step": len(self.decision_log),
            "target": target,
            "task_class": profile.complexity_class,
            "flops": profile.total_flops,
            "memory": profile.total_memory_bytes,
            "ai": profile.arithmetic_intensity,
            "temperature": self.sensor.T_current,
            "temperature_case": self.sensor.T_case,
            "reliability": self.sensor.get_thermal_reliability(),
            "retention_time_s": self.sensor.get_retention_time(),
            "est_latency_ms": est_latency,
            "est_energy_mj": est_energy,
        })
        self.task_log.append(profile)
        self.current_target = target
        self.counts[target] += 1

        return target

    def access(self, addr, pim_op=False, workload_intensity=0.5,
               spike_rate=0.1):
        """Route a request described only by a scalar workload intensity.

        A synthetic ``TaskComplexityProfile`` is built from
        ``workload_intensity``: FLOPs = 1e8*w, memory = 1e6*w, arithmetic
        intensity = 50*w, memory-bound when w < 0.4, and class LIGHT (< 0.3),
        MEDIUM (< 0.7) or HEAVY. The profile is routed with ``w`` as the
        workload override.

        ``addr`` and ``pim_op`` do not influence routing. They are only
        recorded in ``access_log``.

        Returns:
            The chosen target name.
        """
        default_profile = TaskComplexityProfile(
            total_flops=workload_intensity * 1e8,
            total_memory_bytes=workload_intensity * 1e6,
            input_sparsity=1.0 - spike_rate,
            arithmetic_intensity=workload_intensity * 50,
            is_memory_bound=(workload_intensity < 0.4),
            complexity_class="LIGHT" if workload_intensity < 0.3 else
                            ("MEDIUM" if workload_intensity < 0.7 else "HEAVY"),
        )
        target = self.route_task(default_profile, workload_override=workload_intensity)
        self.access_log.append({'addr': addr, 'pim': pim_op, 'target': target})
        return target

    def get_stats(self) -> Dict:
        """Summarize routing decisions made so far.

        The returned dict always has the same keys. With no decisions yet, the
        shares and per-task averages are 0.0 instead of the previous
        inconsistent ``{"PIM": 0, ...}`` dict, which made ``stats["PIM_pct"]``
        raise ``KeyError``.

        Returns:
            Per-target shares (``*_pct``), ``total_accesses``, accumulated
            estimated energy/latency and their per-task averages.
        """
        total = sum(self.counts.values())
        if total == 0:
            return {
                "PIM_pct": 0.0,
                "CPU_pct": 0.0,
                "GPU_pct": 0.0,
                "total_accesses": 0,
                "total_energy_mJ": self.total_energy_mj,
                "total_latency_ms": self.total_latency_ms,
                "avg_energy_per_task_mJ": 0.0,
                "avg_latency_per_task_ms": 0.0,
            }
        return {
            "PIM_pct": self.counts["PIM"] / total,
            "CPU_pct": self.counts["CPU"] / total,
            "GPU_pct": self.counts["GPU"] / total,
            "total_accesses": total,
            "total_energy_mJ": self.total_energy_mj,
            "total_latency_ms": self.total_latency_ms,
            "avg_energy_per_task_mJ": self.total_energy_mj / total,
            "avg_latency_per_task_ms": self.total_latency_ms / total,
        }

    @staticmethod
    def _log_scale(value, decades):
        """Return ``log10(max(value, 1)) / decades`` clipped to [0, 1]."""
        # max(..., 1) keeps log10 finite (>= 0) for zero or tiny values.
        return np.clip(np.log10(max(value, 1)) / decades, 0, 1)

    def _estimate_workload(self, prof):
        """Blend log-scaled FLOPs/memory and linear timesteps into [0, 1].

        Weights: 0.4 FLOPs (saturating at 1e10), 0.3 memory bytes (saturating
        at 1e8), 0.3 timesteps (saturating at 100).
        """
        flops_f = self._log_scale(prof.total_flops, 10)
        mem_f = self._log_scale(prof.total_memory_bytes, 8)
        ts_f = np.clip(prof.timesteps / 100, 0, 1)
        return np.clip(0.4*flops_f + 0.3*mem_f + 0.3*ts_f, 0, 1)

    def _get_physics_state(self, workload, spike_rate):
        """Build the 8-element ``float32`` physics part of the agent state.

        Order: scaled temperature ``(T_current-20)/80``, scaled threshold
        voltage ``(V_th-0.3)/0.6``, fault density ``/0.1``, read margin (raw),
        workload, spike rate, encoding of the previous target, and thermal
        reliability (raw). All features except the two raw ones are clipped to
        [0, 1].
        """
        V_th = self.sensor.get_threshold_voltage()
        fd = self.sensor.get_fault_density()
        rm = self.sensor.get_read_margin()
        rel = self.sensor.get_thermal_reliability()
        return np.array([
            np.clip((self.sensor.T_current - 20)/80, 0, 1),
            np.clip((V_th - 0.3)/0.6, 0, 1),
            np.clip(fd/0.1, 0, 1), rm,
            np.clip(workload, 0, 1), np.clip(spike_rate, 0, 1),
            self._TARGET_ENCODING.get(self.current_target, 0.0), rel,
        ], dtype=np.float32)

    def _get_complexity_state(self, prof):
        """Build the 8-element ``float32`` task-complexity part of the state.

        Order: log-scaled FLOPs, log-scaled memory, arithmetic intensity /100,
        input sparsity, layer depth /10, timesteps /100, complexity-class
        encoding, and memory-bound flag.
        """
        flops_n = self._log_scale(prof.total_flops, 10)
        mem_n = self._log_scale(prof.total_memory_bytes, 8)
        ai_n = np.clip(prof.arithmetic_intensity / 100, 0, 1)
        # NOTE(review): sparsity is passed through unclipped, unlike every other
        # feature; presumably the profiler guarantees [0, 1] — confirm.
        sp_n = prof.input_sparsity
        depth = np.clip((prof.num_conv_layers + prof.num_linear_layers)/10, 0, 1)
        ts_n = np.clip(prof.timesteps / 100, 0, 1)
        mb = 1.0 if prof.is_memory_bound else 0.0
        return np.array([flops_n, mem_n, ai_n, sp_n, depth, ts_n,
                         self._CLASS_ENCODING.get(prof.complexity_class, 0.5), mb],
                        dtype=np.float32)