|
|
import torch |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
def load_model(path='model.safetensors'):
    """Load a safetensors file and return whatever ``load_file`` produces.

    Args:
        path: Location of the safetensors file; defaults to
            ``'model.safetensors'``, resolved against the current working
            directory.

    Returns:
        The object returned by ``safetensors.torch.load_file(path)``. The
        ``__main__`` demo in this file passes it to ``hamming_distance`` as
        the ``weights`` lookup.
    """
    tensors = load_file(path)
    return tensors
|
|
|
|
|
def _fire(x, weights, name):
    """Evaluate one threshold unit: 1 if ``x @ W.T + b >= 0``, otherwise 0.

    ``W`` is ``weights[f'{name}.weight']`` and ``b`` is
    ``weights[f'{name}.bias']``. The thresholded result must hold exactly one
    element, otherwise ``.item()`` raises.
    """
    pre_activation = x @ weights[f'{name}.weight'].T + weights[f'{name}.bias']
    return int((pre_activation >= 0).item())


def _as_tensor(values):
    """Pack an iterable of numbers into a 1-D tensor of Python floats."""
    return torch.tensor([float(v) for v in values])


def hamming_distance(a3, a2, a1, a0, b3, b2, b1, b0, weights):
    """Compute Hamming distance between two 4-bit values. Returns (d2, d1, d0).

    The distance is not computed arithmetically; it is read off a fixed
    five-stage cascade of threshold units whose parameters come from
    ``weights``. Every unit outputs ``int(x @ W.T + b >= 0)``.

    Args:
        a3, a2, a1, a0: Bits of the first value. Each is converted with
            ``float()``; they are fed to the network in this order.
        b3, b2, b1, b0: Bits of the second value, appended after the ``a``
            bits to form the 8-element network input.
        weights: Mapping from parameter name to tensor. For every unit below
            both ``'<unit>.weight'`` and ``'<unit>.bias'`` are looked up:
            ``layer1.diff_hi_{0..3}``, ``layer1.diff_lo_{0..3}``,
            ``layer2.diff_{0..3}``, ``layer3.ge{1..4}``, ``layer4.d2``,
            ``layer4.d1``, ``layer4.d0_part1``, ``layer4.d0_part2`` and
            ``layer5.d0``.
            NOTE(review): weights are presumably shaped ``(1, fan_in)`` with
            ``(1,)`` biases - the code only requires a single-element result.

    Returns:
        Tuple ``(d2, d1, d0)`` of ints, each 0 or 1. The ``__main__`` demo in
        this file combines them as ``4*d2 + 2*d1 + d0``.

    Raises:
        KeyError: If a required entry is missing from ``weights``.
    """
    inp = _as_tensor((a3, a2, a1, a0, b3, b2, b1, b0))

    # Stage 1: two units per index, both looking at the full 8-bit input.
    diff_hi = [_fire(inp, weights, f'layer1.diff_hi_{i}') for i in range(4)]
    diff_lo = [_fire(inp, weights, f'layer1.diff_lo_{i}') for i in range(4)]

    # Stage 2: merge each (hi, lo) pair into a single per-index bit.
    diffs = [
        _fire(_as_tensor((hi, lo)), weights, f'layer2.diff_{i}')
        for i, (hi, lo) in enumerate(zip(diff_hi, diff_lo))
    ]

    # Stage 3: four "ge{k}" units over the per-index bits, k = 1..4.
    diff_tensor = _as_tensor(diffs)
    ge = [_fire(diff_tensor, weights, f'layer3.ge{k}') for k in range(1, 5)]

    # Stage 4: output bits d2 and d1, plus the two partial terms for d0.
    ge_tensor = _as_tensor(ge)
    d2 = _fire(ge_tensor, weights, 'layer4.d2')
    d1 = _fire(ge_tensor, weights, 'layer4.d1')
    d0_part1 = _fire(ge_tensor, weights, 'layer4.d0_part1')
    d0_part2 = _fire(ge_tensor, weights, 'layer4.d0_part2')

    # Stage 5: d0 needs one extra unit to combine its two partial terms.
    d0 = _fire(_as_tensor((d0_part1, d0_part2)), weights, 'layer5.d0')

    return d2, d1, d0
|
|
|
|
|
if __name__ == '__main__':
    # Demo: load the default model file and print the distance for a few
    # hand-picked 4-bit pairs.
    model_weights = load_model()
    print('Hamming Distance 4-bit examples:')

    bit_positions = (3, 2, 1, 0)  # most significant bit first
    examples = (
        (0b0000, 0b0000),
        (0b1111, 0b0000),
        (0b1010, 0b0101),
        (0b1100, 0b1010),
    )
    for a, b in examples:
        a_bits = [(a >> pos) & 1 for pos in bit_positions]
        b_bits = [(b >> pos) & 1 for pos in bit_positions]
        d2, d1, d0 = hamming_distance(*a_bits, *b_bits, model_weights)
        # Reassemble the 3-bit result; d2 is the most significant bit.
        dist = d0 + 2 * d1 + 4 * d2
        print(f' HD({a:04b}, {b:04b}) = {dist}')
|
|
|