File size: 12,626 Bytes
9908537
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
"""Deploy a trained SHD model to the Neurocore SDK or evaluate quantization.



Loads a PyTorch checkpoint from shd_train.py, quantizes weights to int16,

and evaluates accuracy with quantized weights. Also builds an SDK Network

for deployment to the FPGA via CUBA neurons.



Supports both LIF and adLIF checkpoints. For adLIF, adaptation parameters

(rho, beta_a) are training-only; only alpha (membrane decay) deploys as decay_v.



Usage:

    python shd_deploy.py --checkpoint shd_model.pt --data-dir data/shd

    python shd_deploy.py --checkpoint shd_adlif_model.pt --neuron-type adlif

"""

import os
import sys
import argparse
import numpy as np

import torch
from torch.utils.data import DataLoader

# Add SDK and benchmarks to path
_SDK_DIR = os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))
if _SDK_DIR not in sys.path:
    sys.path.insert(0, _SDK_DIR)
sys.path.insert(0, os.path.dirname(__file__))

from shd_loader import SHDDataset, collate_fn, N_CHANNELS, N_CLASSES
from shd_train import SHDSNN

from neurocore import Network
from neurocore.constants import WEIGHT_MIN, WEIGHT_MAX


def quantize_weights(w_float, threshold_float, threshold_hw=1000):
    """Quantize a float weight matrix to integer values for hardware.

    Scales each weight so the hardware dynamics match the training
    dynamics:

        weight_hw = round(w_float * threshold_hw / threshold_float)

    then clamps the result to [WEIGHT_MIN, WEIGHT_MAX].

    Args:
        w_float: (out, in) float32 weight matrix from nn.Linear
        threshold_float: threshold used in training (e.g. 1.0)
        threshold_hw: hardware threshold (default 1000)

    Returns:
        (in, out) int32 weight matrix, transposed to the SDK's
        src->tgt convention.
    """
    rounded = np.round(w_float * (threshold_hw / threshold_float))
    clamped = np.clip(rounded, WEIGHT_MIN, WEIGHT_MAX)
    # nn.Linear stores (out, in); the SDK expects (src, tgt) = (in, out).
    return clamped.astype(np.int32).T


def detect_neuron_type(checkpoint):
    """Infer the neuron model ('adlif' or 'lif') from checkpoint keys.

    adLIF checkpoints carry a 'lif1.alpha_raw' entry in their state
    dict; plain LIF checkpoints do not.
    """
    keys = checkpoint['model_state_dict']
    return 'adlif' if 'lif1.alpha_raw' in keys else 'lif'


def compute_hardware_params(checkpoint, threshold_hw=1000, neuron_type=None):
    """Map trained neuron parameters to hardware (CUBA) settings.

    The membrane decay translates to the CUBA neuron's decay_v field:

        decay_v = round(decay * 4096)   # 12-bit fractional

    For LIF the decay is beta (lif1.beta_raw); for adLIF it is alpha
    (lif1.alpha_raw). The adLIF adaptation parameters (rho, beta_a)
    exist only during training: they are reported but never deployed.

    Args:
        checkpoint: dict from torch.load containing 'model_state_dict'.
        threshold_hw: hardware firing threshold to record in the output.
        neuron_type: 'lif' or 'adlif'; auto-detected when None.

    Returns:
        dict with hardware parameters for each layer.
    """
    state = checkpoint['model_state_dict']
    if neuron_type is None:
        neuron_type = detect_neuron_type(checkpoint)

    params = {'neuron_type': neuron_type}

    def _sigmoid_stats(key, prefix):
        """Record mean/std of sigmoid(raw param) under prefix; return values."""
        raw = state.get(key)
        if raw is None:
            return None
        vals = torch.sigmoid(raw).cpu().numpy()
        params[f'{prefix}_mean'] = float(vals.mean())
        params[f'{prefix}_std'] = float(vals.std())
        return vals

    if neuron_type == 'adlif':
        # Hidden layer: alpha plays the role of the membrane decay.
        alpha = _sigmoid_stats('lif1.alpha_raw', 'hidden_alpha')
        if alpha is not None:
            params['hidden_decay_v'] = int(round(alpha.mean() * 4096))
            # Backward compat: build_sdk_network reads hidden_beta_mean.
            params['hidden_beta_mean'] = float(alpha.mean())

        # Adaptation params are logged for the record, never deployed.
        rho_raw = state.get('lif1.rho_raw')
        if rho_raw is not None:
            params['hidden_rho_mean'] = float(
                torch.sigmoid(rho_raw).cpu().numpy().mean())
            params['hidden_rho_note'] = 'training-only (not deployed)'

        beta_a_raw = state.get('lif1.beta_a_raw')
        if beta_a_raw is not None:
            import torch.nn.functional as F_
            beta_a = F_.softplus(beta_a_raw).cpu().numpy()
            params['hidden_beta_a_mean'] = float(beta_a.mean())
            params['hidden_beta_a_note'] = 'training-only (not deployed)'
    else:
        # LIF: beta is the membrane decay.
        beta_hid = _sigmoid_stats('lif1.beta_raw', 'hidden_beta')
        if beta_hid is not None:
            params['hidden_decay_v'] = int(round(beta_hid.mean() * 4096))

    # The output layer is a standard LIF regardless of hidden neuron type.
    beta_out = _sigmoid_stats('lif2.beta_raw', 'output_beta')
    if beta_out is not None:
        params['output_decay_v'] = int(round(beta_out.mean() * 4096))

    params['threshold_hw'] = threshold_hw
    return params


def build_sdk_network(checkpoint, threshold_hw=1000):
    """Build SDK Network from a trained PyTorch checkpoint.

    Quantizes the three weight matrices (input->hidden, hidden->output,
    hidden->hidden recurrent), assembles a three-population Network, and
    prints quantization statistics.

    Uses subtractive leak as approximation for multiplicative decay.
    True hardware deployment would use CUBA mode with decay_v.

    Args:
        checkpoint: dict from torch.load with 'args' (training
            hyperparameters) and 'model_state_dict'.
        threshold_hw: integer hardware firing threshold.

    Returns:
        net: Network ready for deploy()
        n_hidden: hidden layer size (for reporting)
    """
    args = checkpoint['args']
    threshold_float = args['threshold']
    n_hidden = args['hidden']

    state = checkpoint['model_state_dict']
    w_fc1 = state['fc1.weight'].cpu().numpy()
    w_fc2 = state['fc2.weight'].cpu().numpy()
    w_rec = state['fc_rec.weight'].cpu().numpy()

    # Quantize (also transposes to the SDK's (src, tgt) layout)
    wm_fc1 = quantize_weights(w_fc1, threshold_float, threshold_hw)
    wm_fc2 = quantize_weights(w_fc2, threshold_float, threshold_hw)
    wm_rec = quantize_weights(w_rec, threshold_float, threshold_hw)

    # Approximate decay as subtractive leak (for SDK Simulator compatibility):
    # per step, multiplicative decay v *= beta loses roughly (1-beta)*threshold
    # near threshold, so that amount becomes the constant leak (at least 1).
    hw = compute_hardware_params(checkpoint, threshold_hw)
    leak_hid = max(1, int(round((1 - hw.get('hidden_beta_mean', 0.95)) * threshold_hw)))
    leak_out = max(1, int(round((1 - hw.get('output_beta_mean', 0.9)) * threshold_hw)))

    # Build network
    net = Network()
    # NOTE(review): threshold 65535 presumably keeps input units from firing
    # on their own (spikes injected externally) — confirm against SDK docs.
    inp = net.population(N_CHANNELS,
                         params={'threshold': 65535, 'leak': 0, 'refrac': 0},
                         label="input")
    hid = net.population(n_hidden,
                         params={'threshold': threshold_hw, 'leak': leak_hid, 'refrac': 0},
                         label="hidden")
    out = net.population(N_CLASSES,
                         params={'threshold': threshold_hw, 'leak': leak_out, 'refrac': 0},
                         label="output")

    net.connect(inp, hid, weight_matrix=wm_fc1)
    net.connect(hid, out, weight_matrix=wm_fc2)
    # Recurrent hidden->hidden connection from fc_rec.
    net.connect(hid, hid, weight_matrix=wm_rec)

    # Report stats: nonzero entries = actual synapses after quantization
    # (weights that rounded to zero carry no signal).
    nonzero_fc1 = np.count_nonzero(wm_fc1)
    nonzero_fc2 = np.count_nonzero(wm_fc2)
    nonzero_rec = np.count_nonzero(wm_rec)
    total_conn = nonzero_fc1 + nonzero_fc2 + nonzero_rec
    print(f"Quantized weights (threshold_hw={threshold_hw}):")
    print(f"  fc1: {wm_fc1.shape}, {nonzero_fc1:,} nonzero, "
          f"range [{wm_fc1.min()}, {wm_fc1.max()}]")
    print(f"  fc2: {wm_fc2.shape}, {nonzero_fc2:,} nonzero, "
          f"range [{wm_fc2.min()}, {wm_fc2.max()}]")
    print(f"  rec: {wm_rec.shape}, {nonzero_rec:,} nonzero, "
          f"range [{wm_rec.min()}, {wm_rec.max()}]")
    print(f"  Total connections: {total_conn:,}")
    if 'hidden_decay_v' in hw:
        print(f"  Hardware decay_v (hidden): {hw['hidden_decay_v']} "
              f"(beta={hw['hidden_beta_mean']:.4f})")
    if 'output_decay_v' in hw:
        print(f"  Hardware decay_v (output): {hw['output_decay_v']} "
              f"(beta={hw['output_beta_mean']:.4f})")

    return net, n_hidden


def run_pytorch_quantized_inference(checkpoint, test_ds, device='cpu',

                                     neuron_type=None):
    """Run inference with quantized weights in PyTorch (for comparison).

    Loads the model, replaces float weights with quantized int versions
    (converted back to float), and runs normal forward pass.

    Args:
        checkpoint: dict from torch.load with 'args' and 'model_state_dict'.
        test_ds: test dataset, batched through collate_fn.
        device: torch device string for inference.
        neuron_type: 'lif' or 'adlif'; taken from checkpoint args (or
            auto-detected) when None.

    Returns:
        float: top-1 accuracy over the whole test set.
    """
    args = checkpoint['args']
    threshold_float = args['threshold']
    threshold_hw = 1000
    if neuron_type is None:
        neuron_type = args.get('neuron_type', detect_neuron_type(checkpoint))

    # Rebuild the training-time architecture; .get() defaults cover older
    # checkpoints whose args dict lacks the newer hyperparameters.
    model = SHDSNN(
        n_hidden=args['hidden'],
        threshold=args['threshold'],
        beta_hidden=args.get('beta_hidden', 0.95),
        beta_out=args.get('beta_out', 0.9),
        dropout=0.0,  # no dropout at inference
        neuron_type=neuron_type,
        alpha_init=args.get('alpha_init', 0.90),
        rho_init=args.get('rho_init', 0.85),
        beta_a_init=args.get('beta_a_init', 1.8),
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Quantize and de-quantize weights to simulate quantization error.
    # skip_keys guards neuron parameters in case any ever contain 'weight'
    # in their name; only synaptic weight matrices should be quantized.
    scale = threshold_hw / threshold_float
    skip_keys = ('beta', 'alpha', 'rho', 'threshold_base')
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name and not any(k in name for k in skip_keys):
                # Mirrors quantize_weights: round at hardware scale, clamp
                # to the int16 weight range, then return to float scale.
                q = torch.round(param * scale).clamp(WEIGHT_MIN, WEIGHT_MAX) / scale
                param.copy_(q)

    model.eval()
    loader = DataLoader(test_ds, batch_size=128, shuffle=False,
                        collate_fn=collate_fn, num_workers=0)

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            correct += (output.argmax(1) == labels).sum().item()
            total += inputs.size(0)

    acc = correct / total
    print(f"  PyTorch quantized accuracy: {correct}/{total} = {acc*100:.1f}%")
    return acc


def main():
    """CLI entry point: load a checkpoint, map parameters to hardware,
    evaluate weight-quantization impact, and build the SDK network.
    """
    parser = argparse.ArgumentParser(description="Deploy trained SHD model")
    parser.add_argument("--checkpoint", default="shd_model.pt",
                        help="Path to trained model checkpoint")
    parser.add_argument("--data-dir", default="data/shd")
    parser.add_argument("--n-samples", type=int, default=None,
                        help="Limit test samples (default: all)")
    parser.add_argument("--threshold-hw", type=int, default=1000)
    parser.add_argument("--dt", type=float, default=4e-3)
    parser.add_argument("--neuron-type", choices=["lif", "adlif"], default=None,
                        help="Neuron model (auto-detected from checkpoint if omitted)")
    args = parser.parse_args()

    print(f"Loading checkpoint: {args.checkpoint}")
    # weights_only=False: the checkpoint stores a plain 'args' dict
    # alongside the tensors, which strict weight loading would reject.
    ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
    train_args = ckpt['args']

    # CLI flag wins; otherwise trust the checkpoint's recorded type,
    # falling back to state-dict key sniffing for older checkpoints.
    neuron_type = args.neuron_type or train_args.get('neuron_type', detect_neuron_type(ckpt))
    print(f"  Training accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f"  Architecture: {N_CHANNELS}->{train_args['hidden']}->{N_CLASSES} ({neuron_type.upper()})")

    print("\nLoading test dataset...")
    test_ds = SHDDataset(args.data_dir, "test", dt=args.dt)
    print(f"  {len(test_ds)} samples, {test_ds.n_bins} time bins")

    # 1. Hardware parameter mapping
    print("\n--- Hardware parameter mapping ---")
    hw_params = compute_hardware_params(ckpt, args.threshold_hw, neuron_type)
    for k, v in sorted(hw_params.items()):
        print(f"  {k}: {v}")

    # 2. PyTorch quantized inference (weight quantization impact)
    print("\n--- PyTorch quantized inference ---")
    pytorch_acc = run_pytorch_quantized_inference(ckpt, test_ds,
                                                  neuron_type=neuron_type)

    # 3. Build SDK network (for reference)
    print("\n--- SDK network summary ---")
    net, n_hidden = build_sdk_network(ckpt, threshold_hw=args.threshold_hw)

    # Summary
    print("\n=== Results ===")
    print(f"  PyTorch float accuracy:     {ckpt['test_acc']*100:.1f}%")
    print(f"  PyTorch quantized accuracy: {pytorch_acc*100:.1f}%")
    # Signed drop: positive means quantization hurt accuracy, negative
    # means it helped (the old abs() hid the direction).
    gap = (ckpt['test_acc'] - pytorch_acc) * 100
    print(f"  Quantization loss:          {gap:+.1f}%")
    print(f"\n  Hardware deployment: CUBA mode (decay_v={hw_params.get('hidden_decay_v', 'N/A')})")
    # Fix: the old expression sum(1 for c in net.connections for _ in range(1))
    # was a convoluted len(net.connections) and counted connection matrices,
    # not synapses (per-matrix nonzero counts are printed by build_sdk_network).
    print(f"  Connection matrices: {len(net.connections)}")


if __name__ == "__main__":
    main()