"""Deploy a trained SHD model to the Neurocore SDK or evaluate quantization.
Loads a PyTorch checkpoint from shd_train.py, quantizes weights to int16,
and evaluates accuracy with quantized weights. Also builds an SDK Network
for deployment to the FPGA via CUBA neurons.
Supports both LIF and adLIF checkpoints. For adLIF, adaptation parameters
(rho, beta_a) are training-only; only alpha (membrane decay) deploys as decay_v.
Usage:
python shd_deploy.py --checkpoint shd_model.pt --data-dir data/shd
python shd_deploy.py --checkpoint shd_adlif_model.pt --neuron-type adlif
"""
import os
import sys
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
# Add SDK and benchmarks to path
_SDK_DIR = os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))
if _SDK_DIR not in sys.path:
sys.path.insert(0, _SDK_DIR)
sys.path.insert(0, os.path.dirname(__file__))
from shd_loader import SHDDataset, collate_fn, N_CHANNELS, N_CLASSES
from shd_train import SHDSNN
from neurocore import Network
from neurocore.constants import WEIGHT_MIN, WEIGHT_MAX
def quantize_weights(w_float, threshold_float, threshold_hw=1000):
    """Convert a float weight matrix into hardware integer weights.

    The scaling keeps hardware dynamics aligned with training dynamics:
        weight_hw = round(w_float * threshold_hw / threshold_float)
    saturated to [WEIGHT_MIN, WEIGHT_MAX] = [-32768, 32767].

    Args:
        w_float: (out, in) float32 weights as stored by nn.Linear.
        threshold_float: firing threshold used during training (e.g. 1.0).
        threshold_hw: integer hardware threshold (default 1000).

    Returns:
        (in, out) int32 matrix, transposed into the SDK's (src, tgt) layout.
    """
    rounded = np.round(w_float * (threshold_hw / threshold_float))
    clipped = np.clip(rounded, WEIGHT_MIN, WEIGHT_MAX)
    # nn.Linear stores (out, in); the SDK expects (src, tgt) = (in, out).
    return clipped.astype(np.int32).T
def detect_neuron_type(checkpoint):
    """Infer the neuron model ('adlif' or 'lif') from checkpoint keys.

    adLIF checkpoints carry a learned hidden-layer alpha parameter;
    its presence is the discriminator.
    """
    keys = checkpoint['model_state_dict']
    return 'adlif' if 'lif1.alpha_raw' in keys else 'lif'
def compute_hardware_params(checkpoint, threshold_hw=1000, neuron_type=None):
    """Derive CUBA hardware neuron parameters from a trained checkpoint.

    Membrane decay maps to the hardware as
        decay_v = round(decay * 4096)   (12-bit fractional)
    where decay is beta (lif1.beta_raw) for LIF and alpha
    (lif1.alpha_raw) for adLIF. The adLIF adaptation terms
    (rho, beta_a) have no hardware counterpart and are reported as
    training-only diagnostics.

    Args:
        checkpoint: dict holding 'model_state_dict'.
        threshold_hw: hardware firing threshold recorded in the result.
        neuron_type: 'lif' or 'adlif'; auto-detected when None.

    Returns:
        dict of per-layer hardware parameters and diagnostics.
    """
    state = checkpoint['model_state_dict']
    if neuron_type is None:
        neuron_type = detect_neuron_type(checkpoint)
    params = {'neuron_type': neuron_type}

    def _sigmoid_stats(raw):
        # Raw parameters are stored pre-sigmoid; squash, then summarize.
        vals = torch.sigmoid(raw).cpu().numpy()
        return float(vals.mean()), float(vals.std())

    if neuron_type == 'adlif':
        # For adLIF, alpha is the hidden layer's membrane decay.
        alpha_raw = state.get('lif1.alpha_raw')
        if alpha_raw is not None:
            mean_a, std_a = _sigmoid_stats(alpha_raw)
            params['hidden_alpha_mean'] = mean_a
            params['hidden_alpha_std'] = std_a
            params['hidden_decay_v'] = int(round(mean_a * 4096))
            # Mirrored into hidden_beta_mean for backward compatibility
            # with build_sdk_network, which reads that key.
            params['hidden_beta_mean'] = mean_a
        # Adaptation parameters are logged for the record only.
        rho_raw = state.get('lif1.rho_raw')
        if rho_raw is not None:
            mean_r, _ = _sigmoid_stats(rho_raw)
            params['hidden_rho_mean'] = mean_r
            params['hidden_rho_note'] = 'training-only (not deployed)'
        beta_a_raw = state.get('lif1.beta_a_raw')
        if beta_a_raw is not None:
            # beta_a is stored pre-softplus (it must stay positive).
            beta_a = torch.nn.functional.softplus(beta_a_raw).cpu().numpy()
            params['hidden_beta_a_mean'] = float(beta_a.mean())
            params['hidden_beta_a_note'] = 'training-only (not deployed)'
    else:
        # Standard LIF: beta is the hidden layer's membrane decay.
        beta_hid_raw = state.get('lif1.beta_raw')
        if beta_hid_raw is not None:
            mean_h, std_h = _sigmoid_stats(beta_hid_raw)
            params['hidden_beta_mean'] = mean_h
            params['hidden_beta_std'] = std_h
            params['hidden_decay_v'] = int(round(mean_h * 4096))

    # The readout layer is always a plain LIF population.
    beta_out_raw = state.get('lif2.beta_raw')
    if beta_out_raw is not None:
        mean_o, std_o = _sigmoid_stats(beta_out_raw)
        params['output_beta_mean'] = mean_o
        params['output_beta_std'] = std_o
        params['output_decay_v'] = int(round(mean_o * 4096))

    params['threshold_hw'] = threshold_hw
    return params
def build_sdk_network(checkpoint, threshold_hw=1000):
    """Assemble an SDK Network from a trained PyTorch checkpoint.

    The trained multiplicative membrane decay is approximated with a
    subtractive leak so the SDK Simulator can run the network; a true
    hardware deployment would use CUBA mode with decay_v instead.

    Args:
        checkpoint: dict with 'args' and 'model_state_dict'.
        threshold_hw: integer hardware firing threshold.

    Returns:
        (net, n_hidden): the Network ready for deploy() and the hidden
        layer size (for reporting).
    """
    train_args = checkpoint['args']
    threshold_float = train_args['threshold']
    n_hidden = train_args['hidden']
    state = checkpoint['model_state_dict']

    # Quantize every weight matrix to the hardware integer range.
    quantized = {
        name: quantize_weights(state[key].cpu().numpy(),
                               threshold_float, threshold_hw)
        for name, key in (('fc1', 'fc1.weight'),
                          ('fc2', 'fc2.weight'),
                          ('rec', 'fc_rec.weight'))
    }

    # Subtractive-leak approximation of the trained decay (per layer).
    hw = compute_hardware_params(checkpoint, threshold_hw)
    leak_hid = max(1, int(round((1 - hw.get('hidden_beta_mean', 0.95)) * threshold_hw)))
    leak_out = max(1, int(round((1 - hw.get('output_beta_mean', 0.9)) * threshold_hw)))

    # Input threshold 65535: input population only relays injected spikes.
    net = Network()
    inp = net.population(N_CHANNELS,
                         params={'threshold': 65535, 'leak': 0, 'refrac': 0},
                         label="input")
    hid = net.population(n_hidden,
                         params={'threshold': threshold_hw, 'leak': leak_hid, 'refrac': 0},
                         label="hidden")
    out = net.population(N_CLASSES,
                         params={'threshold': threshold_hw, 'leak': leak_out, 'refrac': 0},
                         label="output")
    net.connect(inp, hid, weight_matrix=quantized['fc1'])
    net.connect(hid, out, weight_matrix=quantized['fc2'])
    net.connect(hid, hid, weight_matrix=quantized['rec'])

    # Report per-matrix quantization statistics.
    print(f"Quantized weights (threshold_hw={threshold_hw}):")
    total_conn = 0
    for name in ('fc1', 'fc2', 'rec'):
        wm = quantized[name]
        nonzero = np.count_nonzero(wm)
        total_conn += nonzero
        print(f" {name}: {wm.shape}, {nonzero:,} nonzero, "
              f"range [{wm.min()}, {wm.max()}]")
    print(f" Total connections: {total_conn:,}")
    if 'hidden_decay_v' in hw:
        print(f" Hardware decay_v (hidden): {hw['hidden_decay_v']} "
              f"(beta={hw['hidden_beta_mean']:.4f})")
    if 'output_decay_v' in hw:
        print(f" Hardware decay_v (output): {hw['output_decay_v']} "
              f"(beta={hw['output_beta_mean']:.4f})")
    return net, n_hidden
def run_pytorch_quantized_inference(checkpoint, test_ds, device='cpu',
                                    neuron_type=None, threshold_hw=1000):
    """Evaluate the checkpoint with quantization-simulated weights.

    Rebuilds the SHDSNN model, replaces every synaptic weight with its
    quantize/de-quantize round trip (scale, round, clamp to
    [WEIGHT_MIN, WEIGHT_MAX], divide back) so the float forward pass
    carries hardware quantization error, then measures accuracy.

    Args:
        checkpoint: dict with 'args' and 'model_state_dict'.
        test_ds: dataset yielding (spikes, label) samples.
        device: torch device for inference.
        neuron_type: 'lif' or 'adlif'; taken from the checkpoint when None.
        threshold_hw: hardware threshold used for the quantization scale.
            Previously hard-coded to 1000; now a parameter (default 1000,
            backward compatible) so it can match the --threshold-hw used
            for deployment.

    Returns:
        float accuracy in [0, 1].
    """
    args = checkpoint['args']
    threshold_float = args['threshold']
    if neuron_type is None:
        neuron_type = args.get('neuron_type', detect_neuron_type(checkpoint))
    model = SHDSNN(
        n_hidden=args['hidden'],
        threshold=args['threshold'],
        beta_hidden=args.get('beta_hidden', 0.95),
        beta_out=args.get('beta_out', 0.9),
        dropout=0.0,  # no dropout at inference
        neuron_type=neuron_type,
        alpha_init=args.get('alpha_init', 0.90),
        rho_init=args.get('rho_init', 0.85),
        beta_a_init=args.get('beta_a_init', 1.8),
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Quantize and de-quantize weights to simulate quantization error.
    # Neuron dynamics parameters (decays, thresholds) stay float:
    # only synaptic weight matrices are quantized on hardware.
    scale = threshold_hw / threshold_float
    skip_keys = ('beta', 'alpha', 'rho', 'threshold_base')
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name and not any(k in name for k in skip_keys):
                q = torch.round(param * scale).clamp(WEIGHT_MIN, WEIGHT_MAX) / scale
                param.copy_(q)
    model.eval()
    loader = DataLoader(test_ds, batch_size=128, shuffle=False,
                        collate_fn=collate_fn, num_workers=0)
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            correct += (output.argmax(1) == labels).sum().item()
            total += inputs.size(0)
    acc = correct / total
    print(f" PyTorch quantized accuracy: {correct}/{total} = {acc*100:.1f}%")
    return acc
def main():
    """CLI entry point: load a checkpoint, map hardware parameters,
    evaluate quantized accuracy, and build the SDK network."""
    parser = argparse.ArgumentParser(description="Deploy trained SHD model")
    parser.add_argument("--checkpoint", default="shd_model.pt",
                        help="Path to trained model checkpoint")
    parser.add_argument("--data-dir", default="data/shd")
    # NOTE(review): --n-samples is parsed but never applied below — either
    # wire it into the dataset slicing or drop the flag.
    parser.add_argument("--n-samples", type=int, default=None,
                        help="Limit test samples (default: all)")
    parser.add_argument("--threshold-hw", type=int, default=1000)
    parser.add_argument("--dt", type=float, default=4e-3)
    parser.add_argument("--neuron-type", choices=["lif", "adlif"], default=None,
                        help="Neuron model (auto-detected from checkpoint if omitted)")
    args = parser.parse_args()

    print(f"Loading checkpoint: {args.checkpoint}")
    # weights_only=False: the checkpoint stores a plain args dict alongside
    # tensors; only load checkpoints from trusted sources.
    ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
    train_args = ckpt['args']
    # Auto-detect neuron type if not specified on the command line.
    neuron_type = args.neuron_type or train_args.get('neuron_type', detect_neuron_type(ckpt))
    print(f" Training accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f" Architecture: {N_CHANNELS}->{train_args['hidden']}->{N_CLASSES} ({neuron_type.upper()})")

    print("\nLoading test dataset...")
    test_ds = SHDDataset(args.data_dir, "test", dt=args.dt)
    print(f" {len(test_ds)} samples, {test_ds.n_bins} time bins")

    # 1. Hardware parameter mapping
    print("\n--- Hardware parameter mapping ---")
    hw_params = compute_hardware_params(ckpt, args.threshold_hw, neuron_type)
    for k, v in sorted(hw_params.items()):
        print(f" {k}: {v}")

    # 2. PyTorch quantized inference (weight quantization impact)
    print("\n--- PyTorch quantized inference ---")
    pytorch_acc = run_pytorch_quantized_inference(ckpt, test_ds,
                                                  neuron_type=neuron_type)

    # 3. Build SDK network (for reference)
    print("\n--- SDK network summary ---")
    net, n_hidden = build_sdk_network(ckpt, threshold_hw=args.threshold_hw)

    # Summary
    print("\n=== Results ===")
    print(f" PyTorch float accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f" PyTorch quantized accuracy: {pytorch_acc*100:.1f}%")
    gap = abs(ckpt['test_acc'] - pytorch_acc) * 100
    print(f" Quantization loss: {gap:.1f}%")
    print(f"\n Hardware deployment: CUBA mode (decay_v={hw_params.get('hidden_decay_v', 'N/A')})")
    # BUG FIX: the old `sum(1 for c in net.connections for _ in range(1))`
    # was a convoluted len(net.connections) mislabeled as "Total synapses" —
    # it counts connection objects (3 projections), not synapses.
    print(f" Total connections: {len(net.connections):,}")


if __name__ == "__main__":
    main()