mrwabbit committed on
Commit
9908537
·
verified ·
1 Parent(s): faada08

Upload shd_deploy.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. shd_deploy.py +303 -0
shd_deploy.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Deploy a trained SHD model to the Neurocore SDK or evaluate quantization.
2
+
3
+ Loads a PyTorch checkpoint from shd_train.py, quantizes weights to int16,
4
+ and evaluates accuracy with quantized weights. Also builds an SDK Network
5
+ for deployment to the FPGA via CUBA neurons.
6
+
7
+ Supports both LIF and adLIF checkpoints. For adLIF, adaptation parameters
8
+ (rho, beta_a) are training-only; only alpha (membrane decay) deploys as decay_v.
9
+
10
+ Usage:
11
+ python shd_deploy.py --checkpoint shd_model.pt --data-dir data/shd
12
+ python shd_deploy.py --checkpoint shd_adlif_model.pt --neuron-type adlif
13
+ """
14
+
15
+ import os
16
+ import sys
17
+ import argparse
18
+ import numpy as np
19
+
20
+ import torch
21
+ from torch.utils.data import DataLoader
22
+
23
+ # Add SDK and benchmarks to path
24
+ _SDK_DIR = os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))
25
+ if _SDK_DIR not in sys.path:
26
+ sys.path.insert(0, _SDK_DIR)
27
+ sys.path.insert(0, os.path.dirname(__file__))
28
+
29
+ from shd_loader import SHDDataset, collate_fn, N_CHANNELS, N_CLASSES
30
+ from shd_train import SHDSNN
31
+
32
+ from neurocore import Network
33
+ from neurocore.constants import WEIGHT_MIN, WEIGHT_MAX
34
+
35
+
36
def quantize_weights(w_float, threshold_float, threshold_hw=1000):
    """Quantize a float weight matrix to integer hardware weights.

    Scales weights so hardware dynamics reproduce training dynamics:
        weight_hw = round(w_float * threshold_hw / threshold_float)
    clamped to [WEIGHT_MIN, WEIGHT_MAX] = [-32768, 32767].

    Args:
        w_float: (out, in) float32 weight matrix from nn.Linear.
        threshold_float: firing threshold used during training (e.g. 1.0).
        threshold_hw: integer hardware threshold (default 1000).

    Returns:
        (in, out) int32 weight matrix, transposed to the SDK's
        source->target convention.
    """
    rescaled = np.round(w_float * (threshold_hw / threshold_float))
    clamped = np.clip(rescaled, WEIGHT_MIN, WEIGHT_MAX)
    # nn.Linear stores (out, in); the SDK expects (src, tgt) = (in, out).
    return clamped.astype(np.int32).T
56
+
57
+
58
def detect_neuron_type(checkpoint):
    """Infer the neuron model ('adlif' or 'lif') from checkpoint keys.

    adLIF checkpoints carry a learned membrane-decay parameter under
    'lif1.alpha_raw'; plain LIF checkpoints do not.
    """
    has_alpha = 'lif1.alpha_raw' in checkpoint['model_state_dict']
    return 'adlif' if has_alpha else 'lif'
64
+
65
+
66
def compute_hardware_params(checkpoint, threshold_hw=1000, neuron_type=None):
    """Derive hardware neuron parameters from a trained checkpoint.

    The CUBA neuron's membrane decay is quantized to a 12-bit fraction:
        decay_v = round(decay * 4096)

    For LIF the decay comes from beta (lif1.beta_raw); for adLIF from
    alpha (lif1.alpha_raw). adLIF adaptation parameters (rho, beta_a)
    exist only during training — they are logged but never deployed.

    Args:
        checkpoint: dict loaded from a shd_train.py checkpoint.
        threshold_hw: integer hardware threshold.
        neuron_type: 'lif' or 'adlif'; auto-detected when None.

    Returns:
        dict of per-layer hardware parameters and bookkeeping notes.
    """
    state = checkpoint['model_state_dict']
    if neuron_type is None:
        neuron_type = detect_neuron_type(checkpoint)

    params = {'neuron_type': neuron_type}

    if neuron_type == 'adlif':
        # Hidden layer: alpha parameterizes the membrane decay.
        alpha_raw = state.get('lif1.alpha_raw')
        if alpha_raw is not None:
            alpha = torch.sigmoid(alpha_raw).cpu().numpy()
            mean_alpha = float(alpha.mean())
            params['hidden_alpha_mean'] = mean_alpha
            params['hidden_alpha_std'] = float(alpha.std())
            params['hidden_decay_v'] = int(round(mean_alpha * 4096))
            # Mirror under the LIF key so build_sdk_network can read
            # one uniform name regardless of neuron type.
            params['hidden_beta_mean'] = mean_alpha

        # Adaptation decay rho: logged for reference, never deployed.
        rho_raw = state.get('lif1.rho_raw')
        if rho_raw is not None:
            rho = torch.sigmoid(rho_raw).cpu().numpy()
            params['hidden_rho_mean'] = float(rho.mean())
            params['hidden_rho_note'] = 'training-only (not deployed)'

        # Adaptation strength beta_a: also training-only.
        beta_a_raw = state.get('lif1.beta_a_raw')
        if beta_a_raw is not None:
            import torch.nn.functional as F_
            beta_a = F_.softplus(beta_a_raw).cpu().numpy()
            params['hidden_beta_a_mean'] = float(beta_a.mean())
            params['hidden_beta_a_note'] = 'training-only (not deployed)'
    else:
        # Plain LIF: beta parameterizes the membrane decay.
        beta_hid_raw = state.get('lif1.beta_raw')
        if beta_hid_raw is not None:
            beta_hid = torch.sigmoid(beta_hid_raw).cpu().numpy()
            mean_beta = float(beta_hid.mean())
            params['hidden_beta_mean'] = mean_beta
            params['hidden_beta_std'] = float(beta_hid.std())
            params['hidden_decay_v'] = int(round(mean_beta * 4096))

    # The readout layer is a standard LIF regardless of hidden type.
    beta_out_raw = state.get('lif2.beta_raw')
    if beta_out_raw is not None:
        beta_out = torch.sigmoid(beta_out_raw).cpu().numpy()
        mean_out = float(beta_out.mean())
        params['output_beta_mean'] = mean_out
        params['output_beta_std'] = float(beta_out.std())
        params['output_decay_v'] = int(round(mean_out * 4096))

    params['threshold_hw'] = threshold_hw
    return params
128
+
129
+
130
def build_sdk_network(checkpoint, threshold_hw=1000):
    """Construct an SDK Network mirroring a trained checkpoint.

    Multiplicative membrane decay is approximated with a subtractive
    leak so the network runs in the SDK Simulator; true hardware
    deployment would use CUBA mode with decay_v instead.

    Args:
        checkpoint: dict loaded from a shd_train.py checkpoint.
        threshold_hw: integer hardware threshold for hidden/output layers.

    Returns:
        (net, n_hidden): the Network ready for deploy() and the hidden
        layer width, for reporting.
    """
    train_args = checkpoint['args']
    threshold_float = train_args['threshold']
    n_hidden = train_args['hidden']

    state = checkpoint['model_state_dict']

    # Quantize each trained matrix (fc1: input->hidden, fc2: hidden->output,
    # fc_rec: hidden->hidden recurrence).
    wm_fc1 = quantize_weights(state['fc1.weight'].cpu().numpy(),
                              threshold_float, threshold_hw)
    wm_fc2 = quantize_weights(state['fc2.weight'].cpu().numpy(),
                              threshold_float, threshold_hw)
    wm_rec = quantize_weights(state['fc_rec.weight'].cpu().numpy(),
                              threshold_float, threshold_hw)

    # Approximate multiplicative decay with a subtractive leak; floor at 1
    # so the membrane always leaks a little.
    hw = compute_hardware_params(checkpoint, threshold_hw)
    leak_hid = max(1, int(round((1 - hw.get('hidden_beta_mean', 0.95)) * threshold_hw)))
    leak_out = max(1, int(round((1 - hw.get('output_beta_mean', 0.9)) * threshold_hw)))

    net = Network()
    # Input neurons act as relays: threshold 65535 so they never self-fire.
    inp = net.population(N_CHANNELS,
                         params={'threshold': 65535, 'leak': 0, 'refrac': 0},
                         label="input")
    hid = net.population(n_hidden,
                         params={'threshold': threshold_hw, 'leak': leak_hid, 'refrac': 0},
                         label="hidden")
    out = net.population(N_CLASSES,
                         params={'threshold': threshold_hw, 'leak': leak_out, 'refrac': 0},
                         label="output")

    net.connect(inp, hid, weight_matrix=wm_fc1)
    net.connect(hid, out, weight_matrix=wm_fc2)
    net.connect(hid, hid, weight_matrix=wm_rec)

    # Report quantization statistics.
    matrices = (('fc1', wm_fc1), ('fc2', wm_fc2), ('rec', wm_rec))
    counts = {name: np.count_nonzero(wm) for name, wm in matrices}
    print(f"Quantized weights (threshold_hw={threshold_hw}):")
    for name, wm in matrices:
        print(f" {name}: {wm.shape}, {counts[name]:,} nonzero, "
              f"range [{wm.min()}, {wm.max()}]")
    print(f" Total connections: {sum(counts.values()):,}")
    if 'hidden_decay_v' in hw:
        print(f" Hardware decay_v (hidden): {hw['hidden_decay_v']} "
              f"(beta={hw['hidden_beta_mean']:.4f})")
    if 'output_decay_v' in hw:
        print(f" Hardware decay_v (output): {hw['output_decay_v']} "
              f"(beta={hw['output_beta_mean']:.4f})")

    return net, n_hidden
196
+
197
+
198
def run_pytorch_quantized_inference(checkpoint, test_ds, device='cpu',
                                    neuron_type=None, threshold_hw=1000):
    """Run inference with quantized weights in PyTorch (for comparison).

    Rebuilds the model, snaps every connection weight to the integer grid
    the hardware would use (quantize, clamp, de-quantize back to float),
    then measures test accuracy with a normal forward pass. Neuron
    parameters (beta/alpha/rho/threshold) are left untouched.

    Args:
        checkpoint: dict loaded from a shd_train.py checkpoint.
        test_ds: SHDDataset test split.
        device: torch device string for evaluation.
        neuron_type: 'lif' or 'adlif'; resolved from checkpoint if None.
        threshold_hw: integer hardware threshold used for the weight
            scaling. Previously hard-coded to 1000 inside the function,
            ignoring the caller's --threshold-hw; now a parameter with
            the same default, so existing calls behave identically.

    Returns:
        float test accuracy in [0, 1].
    """
    args = checkpoint['args']
    threshold_float = args['threshold']
    if neuron_type is None:
        neuron_type = args.get('neuron_type', detect_neuron_type(checkpoint))

    model = SHDSNN(
        n_hidden=args['hidden'],
        threshold=args['threshold'],
        beta_hidden=args.get('beta_hidden', 0.95),
        beta_out=args.get('beta_out', 0.9),
        dropout=0.0,  # no dropout at inference
        neuron_type=neuron_type,
        alpha_init=args.get('alpha_init', 0.90),
        rho_init=args.get('rho_init', 0.85),
        beta_a_init=args.get('beta_a_init', 1.8),
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Round-trip the weights through the integer grid so the float model
    # carries exactly the quantization error the hardware would see.
    # Neuron parameters (decays, thresholds) are skipped.
    scale = threshold_hw / threshold_float
    skip_keys = ('beta', 'alpha', 'rho', 'threshold_base')
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name and not any(k in name for k in skip_keys):
                q = torch.round(param * scale).clamp(WEIGHT_MIN, WEIGHT_MAX) / scale
                param.copy_(q)

    model.eval()
    loader = DataLoader(test_ds, batch_size=128, shuffle=False,
                        collate_fn=collate_fn, num_workers=0)

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            correct += (output.argmax(1) == labels).sum().item()
            total += inputs.size(0)

    acc = correct / total
    print(f" PyTorch quantized accuracy: {correct}/{total} = {acc*100:.1f}%")
    return acc
249
+
250
+
251
def main():
    """Entry point: load a checkpoint, evaluate quantization, build the SDK net."""
    parser = argparse.ArgumentParser(description="Deploy trained SHD model")
    parser.add_argument("--checkpoint", default="shd_model.pt",
                        help="Path to trained model checkpoint")
    parser.add_argument("--data-dir", default="data/shd")
    # NOTE(review): --n-samples is accepted but currently unused below —
    # TODO either wire it into the dataset/loader or drop the flag.
    parser.add_argument("--n-samples", type=int, default=None,
                        help="Limit test samples (default: all)")
    parser.add_argument("--threshold-hw", type=int, default=1000)
    parser.add_argument("--dt", type=float, default=4e-3)
    parser.add_argument("--neuron-type", choices=["lif", "adlif"], default=None,
                        help="Neuron model (auto-detected from checkpoint if omitted)")
    args = parser.parse_args()

    print(f"Loading checkpoint: {args.checkpoint}")
    # weights_only=False: the checkpoint stores plain Python objects (args,
    # accuracy) besides tensors. Pickle can run arbitrary code — only load
    # checkpoints you trust.
    ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
    train_args = ckpt['args']

    # Resolution priority: CLI flag > value stored at training time >
    # heuristic on the state-dict keys.
    neuron_type = args.neuron_type or train_args.get('neuron_type', detect_neuron_type(ckpt))
    print(f" Training accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f" Architecture: {N_CHANNELS}->{train_args['hidden']}->{N_CLASSES} ({neuron_type.upper()})")

    print("\nLoading test dataset...")
    test_ds = SHDDataset(args.data_dir, "test", dt=args.dt)
    print(f" {len(test_ds)} samples, {test_ds.n_bins} time bins")

    # 1. Hardware parameter mapping
    print("\n--- Hardware parameter mapping ---")
    hw_params = compute_hardware_params(ckpt, args.threshold_hw, neuron_type)
    for k, v in sorted(hw_params.items()):
        print(f" {k}: {v}")

    # 2. PyTorch quantized inference (weight quantization impact)
    print("\n--- PyTorch quantized inference ---")
    pytorch_acc = run_pytorch_quantized_inference(ckpt, test_ds,
                                                  neuron_type=neuron_type)

    # 3. Build SDK network (for reference)
    print("\n--- SDK network summary ---")
    net, n_hidden = build_sdk_network(ckpt, threshold_hw=args.threshold_hw)

    # Summary
    print("\n=== Results ===")
    print(f" PyTorch float accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f" PyTorch quantized accuracy: {pytorch_acc*100:.1f}%")
    gap = abs(ckpt['test_acc'] - pytorch_acc) * 100
    print(f" Quantization loss: {gap:.1f}%")
    print(f"\n Hardware deployment: CUBA mode (decay_v={hw_params.get('hidden_decay_v', 'N/A')})")
    # Was an obfuscated len(): sum(1 for c in net.connections for _ in range(1)).
    # NOTE(review): this counts connection objects (projections), not individual
    # synapses — presumably intentional; confirm against the SDK's Network API.
    print(f" Total synapses: {len(net.connections):,}")


if __name__ == "__main__":
    main()