|
|
import torch |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
def load_model(path='model.safetensors'):
    """Load all tensors from the safetensors file at ``path``.

    Returns whatever ``safetensors.torch.load_file`` returns for that file
    (a mapping of tensor name -> tensor), unchanged.
    """
    tensors = load_file(path)
    return tensors
|
|
|
|
|
def _neuron(weights, name, inputs):
    """Evaluate threshold neuron ``name`` on ``inputs`` and return 0 or 1.

    Computes ``int(x @ W.T + b >= 0)`` with ``W = weights[name + '.weight']``
    and ``b = weights[name + '.bias']``.  Assumes the neuron has a single
    output (``.item()`` requires a one-element result).
    """
    weight = weights[f'{name}.weight']
    bias = weights[f'{name}.bias']
    # Build the input in the weight's dtype/device; a float32 input against
    # e.g. float64 weights would make the matmul raise a dtype mismatch.
    x = torch.tensor([float(v) for v in inputs], dtype=weight.dtype, device=weight.device)
    return int((x @ weight.T + bias >= 0).item())


def _xor(weights, prefix, out, inputs):
    """Two-layer gate: neurons ``{prefix}_or`` and ``{prefix}_nand`` feed ``out``.

    By naming, this is presumably XOR realised as AND(OR, NAND); the actual
    function is whatever the loaded weights encode.
    """
    hidden = (
        _neuron(weights, f'{prefix}_or', inputs),
        _neuron(weights, f'{prefix}_nand', inputs),
    )
    return _neuron(weights, out, hidden)


def gray2binary(g3, g2, g1, g0, weights):
    """Convert 4-bit Gray code to binary.

    Args:
        g3, g2, g1, g0: Gray-code bits, most significant first. Each is passed
            through ``float()``; 0/1 values are expected.
        weights: mapping of tensor name -> tensor (e.g. the result of
            ``load_model()``). Must provide ``<name>.weight`` and
            ``<name>.bias`` for b3, x1_or, x1_nand, b2, x2_or, x2_nand, b1,
            x3_or, x3_nand and b0.

    Returns:
        Tuple ``(b3, b2, b1, b0)`` of ints, each 0 or 1.

    Raises:
        KeyError: if a required tensor is missing from ``weights``.
    """
    gray = (g3, g2, g1, g0)
    # Top bit and the first gate stage both see the full 4-bit Gray input.
    b3 = _neuron(weights, 'b3', gray)
    b2 = _xor(weights, 'x1', 'b2', gray)
    # Each lower bit combines the previously decoded bit with the next Gray
    # bit (b_i = b_{i+1} XOR g_i for a correctly trained weight set).
    b1 = _xor(weights, 'x2', 'b1', (b2, g1))
    b0 = _xor(weights, 'x3', 'b0', (b1, g0))
    return b3, b2, b1, b0
|
|
|
|
|
if __name__ == '__main__':
    # Round-trip demo: encode each 4-bit value as Gray code, decode it with
    # the loaded network, and print value -> gray bits -> decoded bits = int.
    model_weights = load_model()
    print('Gray to Binary conversion:')
    print('Binary -> Gray -> Binary')
    shifts = (3, 2, 1, 0)
    for value in range(16):
        code = value ^ (value >> 1)
        gray_bits = tuple((code >> s) & 1 for s in shifts)
        bin_bits = gray2binary(*gray_bits, model_weights)
        decoded = sum(bit << s for bit, s in zip(bin_bits, shifts))
        gray_str = ''.join(map(str, gray_bits))
        bin_str = ''.join(map(str, bin_bits))
        print(f' {value:2d} -> {gray_str} -> {bin_str} = {decoded:2d}')
|
|
|