shd-snn-benchmark / shd_deploy.py
mrwabbit's picture
Upload shd_deploy.py with huggingface_hub
9908537 verified
"""Deploy a trained SHD model to the Neurocore SDK or evaluate quantization.
Loads a PyTorch checkpoint from shd_train.py, quantizes weights to int16,
and evaluates accuracy with quantized weights. Also builds an SDK Network
for deployment to the FPGA via CUBA neurons.
Supports both LIF and adLIF checkpoints. For adLIF, adaptation parameters
(rho, beta_a) are training-only; only alpha (membrane decay) deploys as decay_v.
Usage:
python shd_deploy.py --checkpoint shd_model.pt --data-dir data/shd
python shd_deploy.py --checkpoint shd_adlif_model.pt --neuron-type adlif
"""
import os
import sys
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
# Add SDK and benchmarks to path
_SDK_DIR = os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))
if _SDK_DIR not in sys.path:
sys.path.insert(0, _SDK_DIR)
sys.path.insert(0, os.path.dirname(__file__))
from shd_loader import SHDDataset, collate_fn, N_CHANNELS, N_CLASSES
from shd_train import SHDSNN
from neurocore import Network
from neurocore.constants import WEIGHT_MIN, WEIGHT_MAX
def quantize_weights(w_float, threshold_float, threshold_hw=1000):
    """Quantize a float weight matrix to integers for hardware deployment.

    Weights are rescaled so that the hardware threshold plays the role the
    training threshold played:

        weight_hw = round(w_float * threshold_hw / threshold_float)

    and then clamped to [WEIGHT_MIN, WEIGHT_MAX] (imported from
    neurocore.constants; the original documentation gives this as the int16
    range [-32768, 32767]).

    Args:
        w_float: (out, in) float weight matrix as stored by nn.Linear.
        threshold_float: threshold used in training (e.g. 1.0). Must be a
            positive, finite number.
        threshold_hw: hardware threshold (default 1000).

    Returns:
        (in, out) int32 matrix, i.e. the transpose of the input layout, which
        is the (src, tgt) convention the SDK expects.

    Raises:
        ValueError: if threshold_float is not positive and finite, or if
            w_float contains NaN.
    """
    # A zero threshold would divide by zero, a negative one would silently
    # flip the sign of every weight, and inf/NaN would zero or poison them.
    if not (0 < threshold_float < float("inf")):
        raise ValueError(
            f"threshold_float must be a positive finite number, "
            f"got {threshold_float!r}"
        )

    w_float = np.asarray(w_float)
    # NaN passes through round() and clip() unchanged, and casting NaN to an
    # integer dtype yields an arbitrary value outside the clamp range.
    if np.isnan(w_float).any():
        raise ValueError("w_float contains NaN; refusing to quantize")

    scale = threshold_hw / threshold_float
    w_int = np.clip(np.round(w_float * scale), WEIGHT_MIN, WEIGHT_MAX).astype(np.int32)

    # nn.Linear stores (out, in); the SDK wants (src, tgt) = (in, out).
    return w_int.T
def detect_neuron_type(checkpoint):
    """Infer the neuron model from the keys of the checkpoint's state dict.

    Returns 'adlif' when the state dict holds a 'lif1.alpha_raw' entry and
    'lif' otherwise.
    """
    has_alpha = 'lif1.alpha_raw' in checkpoint['model_state_dict']
    return 'adlif' if has_alpha else 'lif'
def compute_hardware_params(checkpoint, threshold_hw=1000, neuron_type=None):
    """Derive per-layer hardware neuron parameters from a trained checkpoint.

    The membrane decay of each layer is the sigmoid of its raw parameter and
    is mapped to the CUBA decay_v value as

        decay_v = round(mean(decay) * 4096)    (12-bit fractional)

    For LIF the hidden decay comes from 'lif1.beta_raw'; for adLIF it comes
    from 'lif1.alpha_raw'. The adLIF adaptation parameters (rho, beta_a) are
    reported for reference only and flagged as training-only. The output
    layer is always read from 'lif2.beta_raw'. Entries missing from the state
    dict are skipped.

    Args:
        checkpoint: dict holding a 'model_state_dict' mapping.
        threshold_hw: hardware threshold, copied into the result unchanged.
        neuron_type: 'lif' or 'adlif'; auto-detected from the checkpoint
            when None.

    Returns:
        dict with hardware parameters for each layer.
    """
    tensors = checkpoint['model_state_dict']
    kind = detect_neuron_type(checkpoint) if neuron_type is None else neuron_type
    hw = {'neuron_type': kind}

    def _sigmoid_values(key):
        # Sigmoid of the raw parameter as a NumPy array, or None when absent.
        raw = tensors.get(key, None)
        if raw is None:
            return None
        return torch.sigmoid(raw).cpu().numpy()

    if kind == 'adlif':
        # Hidden layer: alpha is the membrane decay.
        alpha = _sigmoid_values('lif1.alpha_raw')
        if alpha is not None:
            hw['hidden_alpha_mean'] = float(alpha.mean())
            hw['hidden_alpha_std'] = float(alpha.std())
            hw['hidden_decay_v'] = int(round(alpha.mean() * 4096))
            # build_sdk_network reads the decay under the 'beta' name.
            hw['hidden_beta_mean'] = float(alpha.mean())

        # Adaptation parameters are only logged, never deployed.
        rho = _sigmoid_values('lif1.rho_raw')
        if rho is not None:
            hw['hidden_rho_mean'] = float(rho.mean())
            hw['hidden_rho_note'] = 'training-only (not deployed)'

        raw_beta_a = tensors.get('lif1.beta_a_raw', None)
        if raw_beta_a is not None:
            from torch.nn.functional import softplus
            hw['hidden_beta_a_mean'] = float(softplus(raw_beta_a).cpu().numpy().mean())
            hw['hidden_beta_a_note'] = 'training-only (not deployed)'
    else:
        # LIF: beta is the membrane decay.
        beta_hidden = _sigmoid_values('lif1.beta_raw')
        if beta_hidden is not None:
            hw['hidden_beta_mean'] = float(beta_hidden.mean())
            hw['hidden_beta_std'] = float(beta_hidden.std())
            hw['hidden_decay_v'] = int(round(beta_hidden.mean() * 4096))

    # The output layer is a standard LIF layer for both neuron types.
    beta_output = _sigmoid_values('lif2.beta_raw')
    if beta_output is not None:
        hw['output_beta_mean'] = float(beta_output.mean())
        hw['output_beta_std'] = float(beta_output.std())
        hw['output_decay_v'] = int(round(beta_output.mean() * 4096))

    hw['threshold_hw'] = threshold_hw
    return hw
def build_sdk_network(checkpoint, threshold_hw=1000):
    """Assemble an SDK Network from a trained PyTorch checkpoint.

    The feed-forward (fc1, fc2) and recurrent (fc_rec) weight matrices are
    quantized and wired between three populations: input, hidden and output.
    Multiplicative membrane decay is approximated by a subtractive leak of
    (1 - mean decay) * threshold_hw, floored at 1, for SDK Simulator
    compatibility; true hardware deployment would use CUBA mode with decay_v.
    A summary of the quantized weights is printed.

    Args:
        checkpoint: dict with 'args' (needs 'threshold' and 'hidden') and
            'model_state_dict' (needs 'fc1.weight', 'fc2.weight',
            'fc_rec.weight').
        threshold_hw: firing threshold for the hidden and output populations.

    Returns:
        net: Network ready for deploy()
        n_hidden: hidden layer size (for reporting)
    """
    train_cfg = checkpoint['args']
    thr_float = train_cfg['threshold']
    n_hidden = train_cfg['hidden']
    tensors = checkpoint['model_state_dict']

    # Fetch every float matrix before quantizing any of them.
    layer_keys = (('fc1', 'fc1.weight'), ('fc2', 'fc2.weight'), ('rec', 'fc_rec.weight'))
    float_mats = [(label, tensors[key].cpu().numpy()) for label, key in layer_keys]
    quantized = {
        label: quantize_weights(mat, thr_float, threshold_hw)
        for label, mat in float_mats
    }

    # Approximate decay as subtractive leak (for SDK Simulator compatibility).
    hw = compute_hardware_params(checkpoint, threshold_hw)

    def _leak(beta_key, fallback):
        decay = hw.get(beta_key, fallback)
        return max(1, int(round((1 - decay) * threshold_hw)))

    leak_hid = _leak('hidden_beta_mean', 0.95)
    leak_out = _leak('output_beta_mean', 0.9)

    net = Network()

    def _population(size, threshold, leak, label):
        return net.population(
            size,
            params={'threshold': threshold, 'leak': leak, 'refrac': 0},
            label=label,
        )

    inp = _population(N_CHANNELS, 65535, 0, "input")
    hid = _population(n_hidden, threshold_hw, leak_hid, "hidden")
    out = _population(N_CLASSES, threshold_hw, leak_out, "output")

    net.connect(inp, hid, weight_matrix=quantized['fc1'])
    net.connect(hid, out, weight_matrix=quantized['fc2'])
    net.connect(hid, hid, weight_matrix=quantized['rec'])

    # Report stats.
    nonzero = {label: np.count_nonzero(mat) for label, mat in quantized.items()}
    total_conn = sum(nonzero.values())
    print(f"Quantized weights (threshold_hw={threshold_hw}):")
    for label, mat in quantized.items():
        print(f" {label}: {mat.shape}, {nonzero[label]:,} nonzero, "
              f"range [{mat.min()}, {mat.max()}]")
    print(f" Total connections: {total_conn:,}")
    for layer in ('hidden', 'output'):
        if f'{layer}_decay_v' in hw:
            decay_v = hw[f'{layer}_decay_v']
            beta_mean = hw[f'{layer}_beta_mean']
            print(f" Hardware decay_v ({layer}): {decay_v} "
                  f"(beta={beta_mean:.4f})")

    return net, n_hidden
def run_pytorch_quantized_inference(checkpoint, test_ds, device='cpu',
                                    neuron_type=None, threshold_hw=1000):
    """Run inference with quantized weights in PyTorch (for comparison).

    Rebuilds the model from the checkpoint's training arguments, replaces
    every weight parameter by its quantized-then-dequantized value
    (round(w * scale) clamped to [WEIGHT_MIN, WEIGHT_MAX], divided by scale
    again, with scale = threshold_hw / training threshold) and measures
    accuracy on test_ds with a normal forward pass.

    Args:
        checkpoint: dict with 'args' and 'model_state_dict'.
        test_ds: dataset whose batches, built with collate_fn, yield
            (inputs, labels).
        device: torch device the model and batches are moved to.
        neuron_type: 'lif' or 'adlif'; taken from the checkpoint args, or
            auto-detected from the state dict, when None.
        threshold_hw: hardware threshold used for the quantization scale.
            Defaults to 1000, the value that used to be hard-coded here.

    Returns:
        Accuracy in [0, 1] as a float.

    Raises:
        ValueError: if test_ds yields no samples.
    """
    args = checkpoint['args']
    threshold_float = args['threshold']
    if neuron_type is None:
        neuron_type = args.get('neuron_type', detect_neuron_type(checkpoint))

    model = SHDSNN(
        n_hidden=args['hidden'],
        threshold=args['threshold'],
        beta_hidden=args.get('beta_hidden', 0.95),
        beta_out=args.get('beta_out', 0.9),
        dropout=0.0,  # no dropout at inference
        neuron_type=neuron_type,
        alpha_init=args.get('alpha_init', 0.90),
        rho_init=args.get('rho_init', 0.85),
        beta_a_init=args.get('beta_a_init', 1.8),
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Quantize and de-quantize weights to simulate quantization error.
    # Parameters whose name mentions a neuron-dynamics quantity are skipped.
    scale = threshold_hw / threshold_float
    skip_keys = ('beta', 'alpha', 'rho', 'threshold_base')
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name and not any(k in name for k in skip_keys):
                q = torch.round(param * scale).clamp(WEIGHT_MIN, WEIGHT_MAX) / scale
                param.copy_(q)

    model.eval()
    loader = DataLoader(test_ds, batch_size=128, shuffle=False,
                        collate_fn=collate_fn, num_workers=0)
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            correct += (output.argmax(1) == labels).sum().item()
            total += inputs.size(0)

    if total == 0:
        # Previously surfaced as a bare ZeroDivisionError.
        raise ValueError("test dataset is empty; cannot compute accuracy")

    acc = correct / total
    print(f" PyTorch quantized accuracy: {correct}/{total} = {acc*100:.1f}%")
    return acc


def main():
    """Command-line entry point: load a checkpoint and report deployment stats.

    Prints the hardware parameter mapping, the accuracy obtained with
    quantized weights in PyTorch, and a summary of the SDK network built
    from the checkpoint.
    """
    parser = argparse.ArgumentParser(description="Deploy trained SHD model")
    parser.add_argument("--checkpoint", default="shd_model.pt",
                        help="Path to trained model checkpoint")
    parser.add_argument("--data-dir", default="data/shd")
    parser.add_argument("--n-samples", type=int, default=None,
                        help="Limit test samples (default: all)")
    parser.add_argument("--threshold-hw", type=int, default=1000)
    parser.add_argument("--dt", type=float, default=4e-3)
    parser.add_argument("--neuron-type", choices=["lif", "adlif"], default=None,
                        help="Neuron model (auto-detected from checkpoint if omitted)")
    args = parser.parse_args()

    if args.n_samples is not None and args.n_samples <= 0:
        parser.error("--n-samples must be a positive integer")
    if args.threshold_hw <= 0:
        parser.error("--threshold-hw must be a positive integer")

    print(f"Loading checkpoint: {args.checkpoint}")
    # NOTE: weights_only=False unpickles arbitrary Python objects; only load
    # checkpoints from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
    train_args = ckpt['args']

    # Explicit flag wins, then the training args, then auto-detection.
    neuron_type = args.neuron_type or train_args.get('neuron_type', detect_neuron_type(ckpt))

    # The value printed here is the checkpoint's stored 'test_acc'.
    print(f" Training accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f" Architecture: {N_CHANNELS}->{train_args['hidden']}->{N_CLASSES} ({neuron_type.upper()})")

    print("\nLoading test dataset...")
    test_ds = SHDDataset(args.data_dir, "test", dt=args.dt)
    print(f" {len(test_ds)} samples, {test_ds.n_bins} time bins")

    # Honour --n-samples, which was previously parsed but never applied.
    if args.n_samples is not None and args.n_samples < len(test_ds):
        from torch.utils.data import Subset
        test_ds = Subset(test_ds, range(args.n_samples))
        print(f" Evaluating on the first {args.n_samples} samples")

    # 1. Hardware parameter mapping
    print("\n--- Hardware parameter mapping ---")
    hw_params = compute_hardware_params(ckpt, args.threshold_hw, neuron_type)
    for k, v in sorted(hw_params.items()):
        print(f" {k}: {v}")

    # 2. PyTorch quantized inference (weight quantization impact).
    # Uses the same hardware threshold as the mapping above and the SDK
    # network below, so all three sections describe the same quantization.
    print("\n--- PyTorch quantized inference ---")
    pytorch_acc = run_pytorch_quantized_inference(ckpt, test_ds,
                                                  neuron_type=neuron_type,
                                                  threshold_hw=args.threshold_hw)

    # 3. Build SDK network (for reference)
    print("\n--- SDK network summary ---")
    net, _n_hidden = build_sdk_network(ckpt, threshold_hw=args.threshold_hw)

    # Summary. The float accuracy is the value stored in the checkpoint, so
    # with --n-samples the two figures are measured on different sample sets.
    print("\n=== Results ===")
    print(f" PyTorch float accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f" PyTorch quantized accuracy: {pytorch_acc*100:.1f}%")
    gap = abs(ckpt['test_acc'] - pytorch_acc) * 100
    print(f" Quantization loss: {gap:.1f}%")
    print(f"\n Hardware deployment: CUBA mode (decay_v={hw_params.get('hidden_decay_v', 'N/A')})")
    # NOTE(review): this counts the entries of net.connections. If the SDK
    # keeps one entry per connect() call, that is the number of connection
    # groups rather than individual synapses -- confirm against the SDK.
    print(f" Total synapses: {sum(1 for _ in net.connections):,}")


if __name__ == "__main__":
    main()