KD099 committed on
Commit
97e6e2b
·
verified ·
1 Parent(s): ca3a9ab

Upload src/hw_in_loop.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/hw_in_loop.py +202 -0
src/hw_in_loop.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ hw_in_loop.py
3
+ =============
4
+ Hardware-in-the-loop (HIL) stub interface (G-5).
5
+ Provides a generic socket-based bridge to real PIM/CPU/GPU hardware measurements.
6
+
7
+ Architecture:
8
+ - HILClient: connects to a local HIL daemon (TCP/Unix socket)
9
+ - HILDaemon: stub server that replays pre-recorded traces or forwards to real driver
10
+ - Measures: latency_ms, energy_mJ, power_mW, temperature_C, accuracy_drop
11
+ """
12
+
13
+ import json
14
+ import time
15
+ import warnings
16
+ from dataclasses import dataclass, asdict
17
+ from typing import Dict, Optional, Callable
18
+ import numpy as np
19
+
20
+ try:
21
+ import torch
22
+ import torch.nn as nn
23
+ HAS_TORCH = True
24
+ except ImportError:
25
+ HAS_TORCH = False
26
+
27
+
28
@dataclass
class HILMeasurement:
    """One hardware-in-the-loop measurement sample.

    Instances are created from keyword arguments (including dictionaries
    decoded from JSON) and can be turned back into a plain dict with
    ``dataclasses.asdict``, so the field names double as the JSON keys.
    """

    target: str                 # hardware target label, e.g. "PIM" or "CPU"
    latency_ms: float
    energy_mJ: float
    power_mW: float
    temperature_C: float
    accuracy_drop: float = 0.0
    timestamp: float = 0.0      # stays 0.0 unless the producer sets it
    raw: Optional[Dict] = None  # free-form extra payload; None becomes {}

    def __post_init__(self):
        # A None sentinel (instead of a shared mutable default) gives every
        # instance its own dict and also tolerates an explicit ``raw=None``
        # such as a JSON ``null``.
        if self.raw is None:
            self.raw = {}
42
+
43
+
44
class HILClient:
    """
    G-5: Hardware-in-the-loop measurement client.

    Connects to a HIL daemon to get real (or trace-replayed) measurements.
    ``measure`` tries, in order: the live daemon connection, a pre-recorded
    trace, and finally a synthetic dummy measurement.  The client can be used
    as a context manager: a connection is attempted on entry (failure only
    warns) and the socket is closed on exit.
    """

    def __init__(self, host: str = "127.0.0.1", port: int = 9876,
                 timeout_s: float = 5.0,
                 trace_file: Optional[str] = None):
        """Store connection settings and optionally preload a replay trace.

        No connection is opened here; call ``connect()`` or use ``with``.
        """
        self.host = host
        self.port = port
        self.timeout_s = timeout_s
        self.sock = None
        self._connected = False
        self._trace: Dict[str, list] = {}  # target -> list of HILMeasurement
        if trace_file:
            self._load_trace(trace_file)

    def _load_trace(self, path: str):
        """Load replay measurements from a JSON trace file, grouped by target.

        The file must contain a JSON list of objects whose keys are
        ``HILMeasurement`` field names.  A missing file only produces a
        warning; malformed content still raises.
        """
        try:
            with open(path, encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            warnings.warn(f"Trace file {path} not found — using dummy traces.")
            return
        for entry in data:
            self._trace.setdefault(entry["target"], []).append(
                HILMeasurement(**entry))

    def connect(self) -> bool:
        """Attempt TCP connection to HIL daemon.

        Returns True when the connection was established.  On any failure a
        warning is emitted and False is returned so callers can keep going
        with the trace/dummy fallbacks.
        """
        import socket

        # Release a socket left over from an earlier connect() so that
        # reconnecting does not leak a file descriptor.
        self.close()
        try:
            self.sock = socket.create_connection((self.host, self.port),
                                                 timeout=self.timeout_s)
        except Exception as exc:  # deliberate best-effort boundary
            warnings.warn(f"HIL connection failed: {exc} — falling back to trace/dummy.")
            self._connected = False
            return False
        self._connected = True
        return True

    @staticmethod
    def _model_hash(model) -> str:
        """Return a hex digest of ``str(model)`` that is stable across runs.

        The built-in ``hash()`` of a string is salted per interpreter process,
        so it cannot identify the same model in two different runs.
        """
        import hashlib

        return hashlib.sha256(str(model).encode("utf-8", "replace")).hexdigest()

    def _recv_response(self):
        """Read one JSON document from the daemon socket and return it decoded.

        Keeps receiving until a newline-terminated line arrives, or until the
        bytes received so far already form a complete JSON document (for a
        peer that omits the trailing newline).

        Raises:
            ConnectionError: the daemon closed the connection first.
            ValueError: a complete line was received but is not valid JSON.
            OSError: socket failure, including a receive timeout.
        """
        buffered = b""
        while True:
            chunk = self.sock.recv(4096)
            if not chunk:
                raise ConnectionError("connection closed by HIL daemon")
            buffered += chunk
            line, newline, _rest = buffered.partition(b"\n")
            try:
                return json.loads(line.decode())
            except ValueError:
                # Without a newline the document may simply be incomplete.
                if newline:
                    raise

    def measure(self, model: "nn.Module", target: str,
                input_shape: tuple, timesteps: int = 1) -> "HILMeasurement":
        """
        G-5: Execute model on target hardware (or replay trace) and return measurement.

        Args:
            model: only ``str(model)`` is used, to derive the ``model_hash``
                sent to the daemon.
            target: hardware target label (e.g. "PIM", "CPU").
            input_shape: workload input shape, forwarded to the daemon.
            timesteps: number of timesteps, forwarded to the daemon.

        Returns:
            The daemon's measurement when connected and the exchange succeeds;
            otherwise the next trace entry for ``target`` (consumed from the
            trace); otherwise a synthetic dummy measurement.  Failures of the
            live path are reported as warnings, never raised.
        """
        if self._connected and self.sock:
            request = {
                "target": target,
                "model_hash": self._model_hash(model),
                "input_shape": input_shape,
                "timesteps": timesteps,
            }
            try:
                self.sock.sendall((json.dumps(request) + "\n").encode())
                data = self._recv_response()
                if isinstance(data, dict) and "error" in data:
                    raise RuntimeError(f"daemon reported error: {data['error']}")
                return HILMeasurement(**data)
            except OSError as exc:
                warnings.warn(f"HIL measure failed: {exc} — using fallback.")
                # After a timeout or reset a late reply could still arrive and
                # be mistaken for the answer to the next request, so drop the
                # connection; connect() can be called again to resume.
                self.close()
            except Exception as exc:
                warnings.warn(f"HIL measure failed: {exc} — using fallback.")

        # Fallback 1: trace replay (each entry is consumed once, in order)
        queued = self._trace.get(target)
        if queued:
            return queued.pop(0)

        # Fallback 2: dummy deterministic measurement
        return self._dummy_measurement(target, input_shape, timesteps)

    @staticmethod
    def _dummy_metrics(target: str, input_shape: tuple, timesteps: int) -> dict:
        """Compute the synthetic latency/energy/power/temperature values.

        The result depends only on the arguments: the temperature jitter comes
        from a generator seeded from them, so the global NumPy random state is
        neither read nor advanced.
        """
        import zlib

        # float64 accumulation avoids integer overflow for very large shapes.
        workload = float(np.prod(input_shape, dtype=np.float64)) * timesteps
        if target == "PIM":
            lat = 0.5 + workload / 1e6 * 0.1
            pwr = 30.0
        elif target == "CPU":
            lat = 2.0 + workload / 1e6 * 0.5
            pwr = 8000.0
        else:
            # Any other target label uses this third profile.
            lat = 0.8 + workload / 1e6 * 0.05
            pwr = 150000.0
        seed = zlib.crc32(f"{target}|{workload!r}".encode("utf-8", "replace"))
        jitter = float(np.random.default_rng(seed).normal(0.0, 2.0))
        return {
            "latency_ms": lat,
            "energy_mJ": pwr * lat / 1000.0,  # mW * ms = uJ; /1000 -> mJ
            "power_mW": pwr,
            "temperature_C": 35.0 + jitter,
        }

    def _dummy_measurement(self, target: str, input_shape: tuple,
                           timesteps: int) -> "HILMeasurement":
        """Build a dummy measurement from target + workload size.

        All fields except ``timestamp`` (current wall-clock time) are
        deterministic for a given set of arguments.
        """
        return HILMeasurement(
            target=target,
            accuracy_drop=0.0,
            timestamp=time.time(),
            **self._dummy_metrics(target, input_shape, timesteps),
        )

    def close(self):
        """Close the daemon socket, if any, and mark the client disconnected."""
        if self.sock:
            self.sock.close()
        self.sock = None
        self._connected = False

    def __enter__(self):
        # A failed connection only warns; the client is still returned so the
        # trace/dummy fallbacks stay usable inside the ``with`` block.
        self.connect()
        return self

    def __exit__(self, *args):
        self.close()
147
+
148
+
149
class HILDaemon:
    """
    G-5: Stub HIL daemon that accepts TCP connections and either
    (a) forwards to a real driver shim, or (b) replays traces.
    This is a server — run in a separate process or thread.

    Requests and responses are JSON documents, one per line.  Connections are
    served one at a time; each is kept open for as many requests as the client
    sends, so a client holding a persistent socket can measure repeatedly.
    """

    def __init__(self, host: str = "127.0.0.1", port: int = 9876,
                 driver_fn: Optional[Callable] = None,
                 trace_file: Optional[str] = None):
        """Store the listen address, optional driver callback and trace.

        Args:
            host: address to bind to.
            port: TCP port to bind to.
            driver_fn: called with the decoded request dict; must return a
                dataclass instance (it is serialised with ``asdict``).
            trace_file: optional JSON trace to replay when no driver is set.
        """
        self.host = host
        self.port = port
        self.driver_fn = driver_fn
        self._trace: Dict[str, list] = {}
        if trace_file:
            self._load_trace(trace_file)

    def _load_trace(self, path: str):
        """Load replay measurements from a JSON trace file, grouped by target.

        A missing or malformed file raises; nothing is caught here.
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        for entry in data:
            self._trace.setdefault(entry["target"], []).append(
                HILMeasurement(**entry))

    def _handle_request(self, req: dict) -> dict:
        """Produce the response dict for one decoded request.

        Order of precedence: the driver callback, then the next unconsumed
        trace entry for the requested target, then a fixed dummy measurement.
        """
        target = req.get("target", "CPU")
        if self.driver_fn:
            return asdict(self.driver_fn(req))
        queued = self._trace.get(target)
        if queued:
            return asdict(queued.pop(0))
        # Dummy
        m = HILMeasurement(
            target=target, latency_ms=1.0, energy_mJ=1.0,
            power_mW=1000.0, temperature_C=40.0)
        return asdict(m)

    def _respond(self, raw: bytes) -> bytes:
        """Turn one raw request into one encoded, newline-terminated reply.

        Any failure while decoding, handling or serialising is reported to the
        client as ``{"error": "<message>"}`` instead of being raised.
        """
        try:
            req = json.loads(raw.decode())
            payload = json.dumps(self._handle_request(req))
        except Exception as exc:  # request boundary: report, don't crash
            payload = json.dumps({"error": str(exc)})
        return (payload + "\n").encode()

    @staticmethod
    def _is_complete_json(raw: bytes) -> bool:
        """Return True when ``raw`` already decodes as a full JSON document."""
        try:
            json.loads(raw.decode())
        except ValueError:
            return False
        return True

    def _serve_connection(self, conn) -> None:
        """Answer every request arriving on ``conn`` until the peer closes it.

        Requests are split on newlines, so one that spans several ``recv``
        calls is reassembled before it is parsed.  A complete request sent
        without a trailing newline is answered immediately as well, to keep
        working with clients that do not terminate their requests.
        """
        pending = b""
        while True:
            chunk = conn.recv(4096)
            if not chunk:
                break
            pending += chunk
            *lines, pending = pending.split(b"\n")
            for line in lines:
                if line.strip():
                    conn.sendall(self._respond(line))
            if pending.strip() and self._is_complete_json(pending):
                conn.sendall(self._respond(pending))
                pending = b""
        # Whatever is left at end-of-stream is answered as-is (normally with
        # an error reply, since it never became a complete document).
        if pending.strip():
            conn.sendall(self._respond(pending))

    def run(self):
        """Bind, listen and serve clients forever (blocking)."""
        import socket

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            # Allow an immediate restart while old connections are still in
            # TIME_WAIT instead of failing with "address already in use".
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind((self.host, self.port))
            s.listen(1)
            print(f"[HILDaemon] listening on {self.host}:{self.port}")
            while True:
                conn, _addr = s.accept()
                with conn:
                    try:
                        self._serve_connection(conn)
                    except OSError as exc:
                        # One misbehaving or vanished client must not take the
                        # whole daemon down.
                        warnings.warn(f"HIL client connection dropped: {exc}")