"""Evaluate a 4-bit Hamming-distance circuit built from threshold units.

Every unit computes ``int(x @ weight.T + bias >= 0)``. Its parameters are read
from a mapping of names such as ``'layer1.diff_hi_0.weight'`` /
``'layer1.diff_hi_0.bias'`` to tensors (normally loaded with ``load_model``).
"""
import torch

try:
    from safetensors.torch import load_file
except ImportError as exc:
    # safetensors is only needed by load_model(); keep the module importable
    # without it so hamming_distance() can run on weights from any source.
    load_file = None
    _SAFETENSORS_IMPORT_ERROR = exc
else:
    _SAFETENSORS_IMPORT_ERROR = None


def load_model(path='model.safetensors'):
    """Load the circuit parameters from a safetensors file.

    Args:
        path: Path of the ``.safetensors`` file to read.

    Returns:
        The ``{name: tensor}`` dict produced by ``safetensors.torch.load_file``.

    Raises:
        ImportError: If the ``safetensors`` package is not installed.
    """
    if load_file is None:
        raise ImportError(
            "load_model() requires the 'safetensors' package"
        ) from _SAFETENSORS_IMPORT_ERROR
    return load_file(path)


def _vec(values):
    """Pack an iterable of bit-like values into a 1-D float tensor."""
    return torch.tensor([float(v) for v in values])


def _neuron(x, weights, name):
    """Evaluate threshold unit *name* on *x*: 1 if pre-activation >= 0, else 0."""
    # matmul does not type-promote, so parameters stored in another dtype
    # (e.g. float64 or an integer type) would raise; cast them to the input's
    # dtype. For float32 weights this is a no-op.
    weight = weights[f'{name}.weight'].to(x.dtype)
    bias = weights[f'{name}.bias'].to(x.dtype)
    return int((x @ weight.T + bias >= 0).item())


def hamming_distance(a3, a2, a1, a0, b3, b2, b1, b0, weights):
    """Compute the Hamming distance between two 4-bit values with the circuit.

    Args:
        a3, a2, a1, a0: Bits of the first value, most significant first.
        b3, b2, b1, b0: Bits of the second value, most significant first.
        weights: Mapping of parameter name to tensor, e.g. from ``load_model``.

    Returns:
        ``(d2, d1, d0)``, the output bits of the distance, so that the
        distance is ``4*d2 + 2*d1 + d0``.

    Raises:
        KeyError: If a required ``<unit>.weight`` / ``<unit>.bias`` entry is
            missing from *weights*.
    """
    inp = _vec((a3, a2, a1, a0, b3, b2, b1, b0))

    # Layer 1: diff_hi and diff_lo for each bit position.
    diff_hi = [_neuron(inp, weights, f'layer1.diff_hi_{i}') for i in range(4)]
    diff_lo = [_neuron(inp, weights, f'layer1.diff_lo_{i}') for i in range(4)]

    # Layer 2: XOR = OR(diff_hi, diff_lo) per bit.
    diffs = [
        _neuron(_vec((hi, lo)), weights, f'layer2.diff_{i}')
        for i, (hi, lo) in enumerate(zip(diff_hi, diff_lo))
    ]

    # Layer 3: threshold detectors ge1..ge4 over the per-bit differences
    # (presumably "at least k bits differ" — depends on the stored weights).
    diff_vec = _vec(diffs)
    ge = [_neuron(diff_vec, weights, f'layer3.ge{k}') for k in range(1, 5)]

    # Layer 4: binary encoding of the count.
    ge_vec = _vec(ge)
    d2 = _neuron(ge_vec, weights, 'layer4.d2')
    d1 = _neuron(ge_vec, weights, 'layer4.d1')
    d0_part1 = _neuron(ge_vec, weights, 'layer4.d0_part1')
    d0_part2 = _neuron(ge_vec, weights, 'layer4.d0_part2')

    # Layer 5: d0 = OR(d0_part1, d0_part2).
    d0 = _neuron(_vec((d0_part1, d0_part2)), weights, 'layer5.d0')

    return d2, d1, d0


def _to_bits(value):
    """Return the four low bits of *value*, most significant first."""
    return [(value >> shift) & 1 for shift in (3, 2, 1, 0)]


def main():
    """Run the circuit on a few examples and cross-check against popcount."""
    w = load_model()
    print('Hamming Distance 4-bit examples:')
    test_cases = [(0b0000, 0b0000), (0b1111, 0b0000), (0b1010, 0b0101), (0b1100, 0b1010)]
    for a, b in test_cases:
        d2, d1, d0 = hamming_distance(*_to_bits(a), *_to_bits(b), w)
        dist = 4*d2 + 2*d1 + d0
        print(f' HD({a:04b}, {b:04b}) = {dist}')
        # Reference answer: number of set bits in a XOR b.
        expected = bin(a ^ b).count('1')
        if dist != expected:
            print(f'   MISMATCH: expected {expected}')


if __name__ == '__main__':
    main()