File size: 835 Bytes
a9b8c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read all tensors from the safetensors checkpoint at ``path``.

    Args:
        path: Filesystem path of the ``.safetensors`` file to read.

    Returns:
        The name -> tensor mapping produced by ``safetensors.torch.load_file``.
    """
    tensors = load_file(path)
    return tensors

def reverse8(a7, a6, a5, a4, a3, a2, a1, a0, weights):
    """8-bit bit reversal evaluated with a threshold network.

    The inputs form the vector ``x = [a7, a6, a5, a4, a3, a2, a1, a0]``.
    Each output ``y{i}`` is a single threshold unit: it is 1 when
    ``x @ W_i.T + b_i >= 0`` and 0 otherwise. ``W_i`` is
    ``weights[f'y{i}.weight']`` and ``b_i`` is ``weights[f'y{i}.bias']``.
    Whether the result is actually the reversed bit pattern depends on the
    loaded weights.

    Args:
        a7 ... a0: Input bits, listed as ``a7`` first. Each is passed through
            ``float()``.
        weights: Mapping containing ``'y0.weight'`` / ``'y0.bias'`` through
            ``'y7.weight'`` / ``'y7.bias'``, for example the dict returned by
            ``load_model``. Presumably each weight is shaped (1, 8) and each
            bias (1,). TODO: confirm against the checkpoint.

    Returns:
        list[int]: ``[y0, y1, ..., y7]``, each 0 or 1.
    """
    inp = torch.tensor([float(a7), float(a6), float(a5), float(a4),
                        float(a3), float(a2), float(a1), float(a0)])
    outputs = []
    for i in range(8):
        w = weights[f'y{i}.weight']
        b = weights[f'y{i}.bias']
        # torch.matmul does not type-promote. Match the weight's dtype and
        # device so fp16/bf16/fp64 or non-CPU checkpoints don't raise.
        # For float32 CPU weights, .to() returns ``inp`` unchanged.
        x = inp.to(dtype=w.dtype, device=w.device)
        y = int((x @ w.T + b >= 0).item())
        outputs.append(y)
    return outputs

if __name__ == '__main__':
    # Demo: load the default checkpoint and print a few example evaluations.
    model_weights = load_model()
    print('reverse8 examples:')
    examples = (
        (1, 0, 0, 0, 0, 0, 0, 0),
        (0, 0, 0, 0, 0, 0, 0, 1),
        (1, 0, 1, 0, 0, 1, 0, 1),
    )
    for example in examples:
        shown = list(example)
        print(f'  {shown} -> {reverse8(*example, model_weights)}')