Upload shd_deploy.py with huggingface_hub
Browse files- shd_deploy.py +303 -0
shd_deploy.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Deploy a trained SHD model to the Neurocore SDK or evaluate quantization.
|
| 2 |
+
|
| 3 |
+
Loads a PyTorch checkpoint from shd_train.py, quantizes weights to int16,
|
| 4 |
+
and evaluates accuracy with quantized weights. Also builds an SDK Network
|
| 5 |
+
for deployment to the FPGA via CUBA neurons.
|
| 6 |
+
|
| 7 |
+
Supports both LIF and adLIF checkpoints. For adLIF, adaptation parameters
|
| 8 |
+
(rho, beta_a) are training-only; only alpha (membrane decay) deploys as decay_v.
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python shd_deploy.py --checkpoint shd_model.pt --data-dir data/shd
|
| 12 |
+
python shd_deploy.py --checkpoint shd_adlif_model.pt --neuron-type adlif
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import argparse
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from torch.utils.data import DataLoader
|
| 22 |
+
|
| 23 |
+
# Add SDK and benchmarks to path
|
| 24 |
+
_SDK_DIR = os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))
|
| 25 |
+
if _SDK_DIR not in sys.path:
|
| 26 |
+
sys.path.insert(0, _SDK_DIR)
|
| 27 |
+
sys.path.insert(0, os.path.dirname(__file__))
|
| 28 |
+
|
| 29 |
+
from shd_loader import SHDDataset, collate_fn, N_CHANNELS, N_CLASSES
|
| 30 |
+
from shd_train import SHDSNN
|
| 31 |
+
|
| 32 |
+
from neurocore import Network
|
| 33 |
+
from neurocore.constants import WEIGHT_MIN, WEIGHT_MAX
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def quantize_weights(w_float, threshold_float, threshold_hw=1000):
    """Convert a float weight matrix to the hardware integer range.

    The scaling preserves training dynamics on hardware:
        weight_hw = round(w_float * threshold_hw / threshold_float)
    with the result clamped to [WEIGHT_MIN, WEIGHT_MAX] = [-32768, 32767].

    Args:
        w_float: (out, in) float32 weight matrix as stored by nn.Linear.
        threshold_float: firing threshold used during training (e.g. 1.0).
        threshold_hw: integer hardware threshold (default 1000).

    Returns:
        (in, out) int32 matrix, transposed to the SDK's (src, tgt) layout.
    """
    rescaled = np.round(w_float * (threshold_hw / threshold_float))
    clamped = np.clip(rescaled, WEIGHT_MIN, WEIGHT_MAX)
    # nn.Linear stores (out, in); the SDK expects (src, tgt) == (in, out).
    return clamped.astype(np.int32).T
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def detect_neuron_type(checkpoint):
    """Infer the neuron model ('adlif' or 'lif') from checkpoint keys.

    adLIF checkpoints carry a learned membrane-decay parameter under
    'lif1.alpha_raw'; plain LIF checkpoints do not.
    """
    has_alpha = 'lif1.alpha_raw' in checkpoint['model_state_dict']
    return 'adlif' if has_alpha else 'lif'
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def compute_hardware_params(checkpoint, threshold_hw=1000, neuron_type=None):
    """Derive hardware neuron parameters from a trained checkpoint.

    The membrane decay maps onto the CUBA neuron's decay_v register as
        decay_v = round(decay * 4096)   # 12-bit fractional value

    For LIF the decay is beta (lif1.beta_raw); for adLIF it is alpha
    (lif1.alpha_raw). The adLIF adaptation parameters (rho, beta_a) only
    affect training and are reported but never deployed.

    Args:
        checkpoint: dict holding a 'model_state_dict' entry.
        threshold_hw: integer hardware firing threshold.
        neuron_type: 'lif' or 'adlif'; auto-detected when None.

    Returns:
        dict of per-layer hardware parameters and diagnostics.
    """
    state = checkpoint['model_state_dict']
    if neuron_type is None:
        neuron_type = detect_neuron_type(checkpoint)

    params = {'neuron_type': neuron_type}

    if neuron_type == 'adlif':
        # Hidden-layer membrane decay comes from alpha.
        alpha_raw = state.get('lif1.alpha_raw')
        if alpha_raw is not None:
            alpha = torch.sigmoid(alpha_raw).cpu().numpy()
            mean_alpha = float(alpha.mean())
            params['hidden_alpha_mean'] = mean_alpha
            params['hidden_alpha_std'] = float(alpha.std())
            params['hidden_decay_v'] = int(round(alpha.mean() * 4096))
            # Mirrored under hidden_beta_mean so build_sdk_network can
            # read one key regardless of neuron type.
            params['hidden_beta_mean'] = mean_alpha

        # Adaptation parameters exist only during training.
        rho_raw = state.get('lif1.rho_raw')
        if rho_raw is not None:
            rho = torch.sigmoid(rho_raw).cpu().numpy()
            params['hidden_rho_mean'] = float(rho.mean())
            params['hidden_rho_note'] = 'training-only (not deployed)'

        beta_a_raw = state.get('lif1.beta_a_raw')
        if beta_a_raw is not None:
            import torch.nn.functional as F_
            beta_a = F_.softplus(beta_a_raw).cpu().numpy()
            params['hidden_beta_a_mean'] = float(beta_a.mean())
            params['hidden_beta_a_note'] = 'training-only (not deployed)'
    else:
        # Plain LIF: beta is the membrane decay.
        beta_hid_raw = state.get('lif1.beta_raw')
        if beta_hid_raw is not None:
            beta_hid = torch.sigmoid(beta_hid_raw).cpu().numpy()
            params['hidden_beta_mean'] = float(beta_hid.mean())
            params['hidden_beta_std'] = float(beta_hid.std())
            params['hidden_decay_v'] = int(round(beta_hid.mean() * 4096))

    # The readout layer is a standard LIF regardless of the hidden type.
    beta_out_raw = state.get('lif2.beta_raw')
    if beta_out_raw is not None:
        beta_out = torch.sigmoid(beta_out_raw).cpu().numpy()
        params['output_beta_mean'] = float(beta_out.mean())
        params['output_beta_std'] = float(beta_out.std())
        params['output_decay_v'] = int(round(beta_out.mean() * 4096))

    params['threshold_hw'] = threshold_hw
    return params
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def build_sdk_network(checkpoint, threshold_hw=1000):
    """Assemble an SDK Network from a trained PyTorch checkpoint.

    Multiplicative membrane decay is approximated by a subtractive leak so
    the network can run in the SDK Simulator; true hardware deployment
    would use CUBA mode with decay_v instead.

    Args:
        checkpoint: checkpoint dict holding 'args' and 'model_state_dict'.
        threshold_hw: integer hardware firing threshold.

    Returns:
        net: Network ready for deploy()
        n_hidden: hidden layer size (for reporting)
    """
    train_args = checkpoint['args']
    threshold_float = train_args['threshold']
    n_hidden = train_args['hidden']

    # Quantize each weight matrix straight from the state dict.
    state = checkpoint['model_state_dict']
    wm_fc1 = quantize_weights(state['fc1.weight'].cpu().numpy(),
                              threshold_float, threshold_hw)
    wm_fc2 = quantize_weights(state['fc2.weight'].cpu().numpy(),
                              threshold_float, threshold_hw)
    wm_rec = quantize_weights(state['fc_rec.weight'].cpu().numpy(),
                              threshold_float, threshold_hw)

    # Subtractive-leak approximation of the trained decay factors
    # (for SDK Simulator compatibility).
    hw = compute_hardware_params(checkpoint, threshold_hw)
    leak_hid = max(1, int(round((1 - hw.get('hidden_beta_mean', 0.95)) * threshold_hw)))
    leak_out = max(1, int(round((1 - hw.get('output_beta_mean', 0.9)) * threshold_hw)))

    net = Network()
    # Input population never self-fires: threshold set unreachably high.
    inp = net.population(N_CHANNELS,
                         params={'threshold': 65535, 'leak': 0, 'refrac': 0},
                         label="input")
    hid = net.population(n_hidden,
                         params={'threshold': threshold_hw, 'leak': leak_hid, 'refrac': 0},
                         label="hidden")
    out = net.population(N_CLASSES,
                         params={'threshold': threshold_hw, 'leak': leak_out, 'refrac': 0},
                         label="output")

    net.connect(inp, hid, weight_matrix=wm_fc1)
    net.connect(hid, out, weight_matrix=wm_fc2)
    net.connect(hid, hid, weight_matrix=wm_rec)

    # Connectivity / quantization report.
    layers = (('fc1', wm_fc1), ('fc2', wm_fc2), ('rec', wm_rec))
    nonzeros = {name: np.count_nonzero(m) for name, m in layers}
    print(f"Quantized weights (threshold_hw={threshold_hw}):")
    for name, m in layers:
        print(f" {name}: {m.shape}, {nonzeros[name]:,} nonzero, "
              f"range [{m.min()}, {m.max()}]")
    print(f" Total connections: {sum(nonzeros.values()):,}")
    if 'hidden_decay_v' in hw:
        print(f" Hardware decay_v (hidden): {hw['hidden_decay_v']} "
              f"(beta={hw['hidden_beta_mean']:.4f})")
    if 'output_decay_v' in hw:
        print(f" Hardware decay_v (output): {hw['output_decay_v']} "
              f"(beta={hw['output_beta_mean']:.4f})")

    return net, n_hidden
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def run_pytorch_quantized_inference(checkpoint, test_ds, device='cpu',
                                    neuron_type=None, threshold_hw=1000):
    """Run inference with quantized weights in PyTorch (for comparison).

    Rebuilds the trained model, replaces every float weight matrix with its
    quantize/de-quantize round trip (so only quantization error remains),
    and evaluates accuracy with a normal forward pass.

    Args:
        checkpoint: checkpoint dict holding 'args' and 'model_state_dict'.
        test_ds: SHD test dataset.
        device: torch device string for inference.
        neuron_type: 'lif' or 'adlif'; resolved from the checkpoint if None.
        threshold_hw: hardware threshold used for the quantization scale.
            Previously hard-coded to 1000 inside the body; parameterized
            (same default) for consistency with build_sdk_network.

    Returns:
        float test accuracy in [0, 1].
    """
    args = checkpoint['args']
    threshold_float = args['threshold']
    if neuron_type is None:
        neuron_type = args.get('neuron_type', detect_neuron_type(checkpoint))

    model = SHDSNN(
        n_hidden=args['hidden'],
        threshold=args['threshold'],
        beta_hidden=args.get('beta_hidden', 0.95),
        beta_out=args.get('beta_out', 0.9),
        dropout=0.0,  # no dropout at inference
        neuron_type=neuron_type,
        alpha_init=args.get('alpha_init', 0.90),
        rho_init=args.get('rho_init', 0.85),
        beta_a_init=args.get('beta_a_init', 1.8),
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Quantize and de-quantize weights to simulate quantization error.
    # Neuron dynamics parameters (beta/alpha/rho/threshold_base) stay float.
    scale = threshold_hw / threshold_float
    skip_keys = ('beta', 'alpha', 'rho', 'threshold_base')
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name and not any(k in name for k in skip_keys):
                q = torch.round(param * scale).clamp(WEIGHT_MIN, WEIGHT_MAX) / scale
                param.copy_(q)

    model.eval()
    loader = DataLoader(test_ds, batch_size=128, shuffle=False,
                        collate_fn=collate_fn, num_workers=0)

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            correct += (output.argmax(1) == labels).sum().item()
            total += inputs.size(0)

    # Guard against an empty dataset instead of dividing by zero.
    acc = correct / total if total else 0.0
    print(f" PyTorch quantized accuracy: {correct}/{total} = {acc*100:.1f}%")
    return acc
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def main():
    """CLI entry point: load a checkpoint, map hardware parameters,
    evaluate weight quantization, and build the SDK network summary."""
    parser = argparse.ArgumentParser(description="Deploy trained SHD model")
    parser.add_argument("--checkpoint", default="shd_model.pt",
                        help="Path to trained model checkpoint")
    parser.add_argument("--data-dir", default="data/shd")
    # NOTE(review): --n-samples is parsed but currently unused — wire it
    # into the dataset/evaluation or remove it.
    parser.add_argument("--n-samples", type=int, default=None,
                        help="Limit test samples (default: all)")
    parser.add_argument("--threshold-hw", type=int, default=1000)
    parser.add_argument("--dt", type=float, default=4e-3)
    parser.add_argument("--neuron-type", choices=["lif", "adlif"], default=None,
                        help="Neuron model (auto-detected from checkpoint if omitted)")
    args = parser.parse_args()

    print(f"Loading checkpoint: {args.checkpoint}")
    # weights_only=False: the checkpoint stores plain args dicts alongside
    # tensors; only load trusted local files this way.
    ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
    train_args = ckpt['args']

    # Auto-detect neuron type if not specified
    neuron_type = args.neuron_type or train_args.get('neuron_type', detect_neuron_type(ckpt))
    print(f" Training accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f" Architecture: {N_CHANNELS}->{train_args['hidden']}->{N_CLASSES} ({neuron_type.upper()})")

    print("\nLoading test dataset...")
    test_ds = SHDDataset(args.data_dir, "test", dt=args.dt)
    print(f" {len(test_ds)} samples, {test_ds.n_bins} time bins")

    # 1. Hardware parameter mapping
    print("\n--- Hardware parameter mapping ---")
    hw_params = compute_hardware_params(ckpt, args.threshold_hw, neuron_type)
    for k, v in sorted(hw_params.items()):
        print(f" {k}: {v}")

    # 2. PyTorch quantized inference (weight quantization impact)
    print("\n--- PyTorch quantized inference ---")
    pytorch_acc = run_pytorch_quantized_inference(ckpt, test_ds,
                                                  neuron_type=neuron_type)

    # 3. Build SDK network (for reference)
    print("\n--- SDK network summary ---")
    net, n_hidden = build_sdk_network(ckpt, threshold_hw=args.threshold_hw)

    # Summary
    print("\n=== Results ===")
    print(f" PyTorch float accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f" PyTorch quantized accuracy: {pytorch_acc*100:.1f}%")
    gap = abs(ckpt['test_acc'] - pytorch_acc) * 100
    print(f" Quantization loss: {gap:.1f}%")
    print(f"\n Hardware deployment: CUBA mode (decay_v={hw_params.get('hidden_decay_v', 'N/A')})")
    # BUG FIX: the old `sum(1 for c in net.connections for _ in range(1))`
    # was a roundabout len() and the "Total synapses" label was wrong —
    # it counts connection matrices (3 here), not individual synapses.
    print(f" Connection matrices: {len(net.connections):,}")


if __name__ == "__main__":
    main()
|