# Author: CharlesCNorton
# Add native forward function for GPU-accelerated evaluation (commit 7af59a2)
import torch
def forward(inputs, weights: dict):
    """
    CRC-16 CCITT single-step using magnitude-optimal XOR components.

    Handles both single inputs (for test building) and batched tensors
    (for GPU eval).

    Args:
        inputs: a Python list of floats (single sample), a 1-D tensor
            (single sample), or a 2-D tensor of shape (batch, n_features).
        weights: mapping from '<name>.weight' / '<name>.bias' to tensors;
            weight tensors are flattened, bias tensors must be one element.

    Returns:
        For a single input, a plain list of 16 floats (each 0.0 or 1.0);
        for a batched input, a (batch, 16) float tensor on ``inputs.device``.
    """
    if isinstance(inputs, list):
        # List input (from test case building): move to the weights' device.
        inputs = torch.tensor(inputs, dtype=torch.float32)
        device = next(iter(weights.values())).device
        inputs = inputs.to(device)
        # BUG FIX: the list path previously never unsqueezed, so layer-1
        # outputs were 0-dim and torch.stack(..., dim=1) raised
        # "Dimension out of range". Promote to a batch of one.
        inputs = inputs.unsqueeze(0)
        single_input = True
    else:
        single_input = inputs.dim() == 1
        if single_input:
            inputs = inputs.unsqueeze(0)

    # Ensure weights are evaluated on the same device as the inputs.
    device = inputs.device

    def neuron(name, inp):
        # Linear threshold unit: 1.0 where w·x + b >= 0, else 0.0.
        # Shared by both layers (the original duplicated this as neuron2).
        w = weights[f'{name}.weight'].to(device).flatten()
        b = weights[f'{name}.bias'].to(device).item()
        return ((inp @ w) + b >= 0).float()

    # Layer 1: bits c1-c4, c6-c11, c13-c15 are single thresholds on the raw
    # inputs; c0, c5, c12 need small hidden XOR sub-networks.
    direct = {
        name: neuron(name, inputs)
        for name in ('c1', 'c2', 'c3', 'c4', 'c6', 'c7', 'c8', 'c9',
                     'c10', 'c11', 'c13', 'c14', 'c15')
    }
    c0_hidden = torch.stack(
        [neuron('c0.h1', inputs), neuron('c0.h2', inputs)], dim=1)
    c5_hidden = torch.stack(
        [neuron('c5.h1', inputs), neuron('c5.h2', inputs),
         neuron('c5.h3', inputs)], dim=1)
    c12_hidden = torch.stack(
        [neuron('c12.h1', inputs), neuron('c12.h2', inputs),
         neuron('c12.h3', inputs)], dim=1)

    # Layer 2: XOR outputs computed from the hidden activations.
    c0_out = neuron('c0.out', c0_hidden)
    c5_out = neuron('c5.out', c5_hidden)
    c12_out = neuron('c12.out', c12_hidden)

    # Assemble the 16 CRC bits in order c0..c15.
    outputs = torch.stack([
        c0_out, direct['c1'], direct['c2'], direct['c3'], direct['c4'],
        c5_out, direct['c6'], direct['c7'], direct['c8'], direct['c9'],
        direct['c10'], direct['c11'], c12_out, direct['c13'],
        direct['c14'], direct['c15'],
    ], dim=1)

    if single_input:
        # Single-sample callers get a plain Python list of 16 floats.
        return outputs.squeeze(0).tolist()
    return outputs