File size: 1,986 Bytes
462e438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from safetensors.torch import load_file

def load_model(path='model.safetensors'):
    """Read the gate parameters stored in a safetensors file.

    Args:
        path: Location of the safetensors file; defaults to
            ``'model.safetensors'`` relative to the working directory.

    Returns:
        Whatever ``safetensors.torch.load_file`` returns for ``path``; the
        rest of this file indexes it by names such as ``'a0b0.weight'``.
    """
    tensors = load_file(path)
    return tensors

def and_gate(x, y, w, name):
    """Evaluate the two-input threshold gate stored under ``name``.

    Args:
        x: First input bit (anything ``float()`` accepts).
        y: Second input bit (anything ``float()`` accepts).
        w: Mapping that holds ``'<name>.weight'`` and ``'<name>.bias'``.
        name: Key prefix selecting which gate's parameters to use.

    Returns:
        ``1`` if the weighted input sum plus the bias is non-negative,
        otherwise ``0``.
    """
    bits = torch.tensor([float(x), float(y)])
    # Element-wise product followed by a sum, then shifted by the bias.
    pre_activation = (bits * w[f'{name}.weight']).sum() + w[f'{name}.bias']
    # Heaviside step: the gate fires when the pre-activation reaches zero.
    fired = pre_activation >= 0
    return int(fired)

def half_adder_sum(x, y, w, prefix):
    """Compute the sum bit of a half adder with a two-layer threshold network.

    Args:
        x: First input bit (anything ``float()`` accepts).
        y: Second input bit (anything ``float()`` accepts).
        w: Mapping that holds ``'<prefix>.xor.layer1.weight'``,
            ``'<prefix>.xor.layer1.bias'``, ``'<prefix>.xor.layer2.weight'``
            and ``'<prefix>.xor.layer2.bias'``.
        prefix: Key prefix selecting which half adder's parameters to use.

    Returns:
        ``1`` if the second layer's pre-activation is non-negative,
        otherwise ``0``.
    """
    bits = torch.tensor([float(x), float(y)])

    # Hidden layer: affine map, then a hard step converted back to floats
    # so it can feed the next matrix product.
    hidden_pre = bits @ w[f'{prefix}.xor.layer1.weight'].T + w[f'{prefix}.xor.layer1.bias']
    hidden = (hidden_pre >= 0).float()

    # Output layer: same affine-then-step pattern, reduced to a Python int.
    out_pre = hidden @ w[f'{prefix}.xor.layer2.weight'].T + w[f'{prefix}.xor.layer2.bias']
    fired = (out_pre >= 0).item()
    return int(fired)

def half_adder_carry(x, y, w, prefix):
    """Compute the carry bit of a half adder with a single threshold gate.

    Args:
        x: First input bit (anything ``float()`` accepts).
        y: Second input bit (anything ``float()`` accepts).
        w: Mapping that holds ``'<prefix>.carry.weight'`` and
            ``'<prefix>.carry.bias'``.
        prefix: Key prefix selecting which half adder's parameters to use.

    Returns:
        ``1`` if the weighted input sum plus the bias is non-negative,
        otherwise ``0``.
    """
    bits = torch.tensor([float(x), float(y)])
    # Weighted sum of the two inputs, shifted by the gate's bias.
    pre_activation = (bits * w[f'{prefix}.carry.weight']).sum() + w[f'{prefix}.carry.bias']
    # Step activation: fire once the pre-activation reaches zero.
    fired = pre_activation >= 0
    return int(fired)

def multiply_2x2(a0, a1, b0, b1, weights):
    """Multiply two 2-bit numbers with the gate evaluators in this file.

    Args:
        a0: Low bit of the first operand.
        a1: High bit of the first operand.
        b0: Low bit of the second operand.
        b1: High bit of the second operand.
        weights: Parameter mapping passed through to ``and_gate``,
            ``half_adder_sum`` and ``half_adder_carry``; it is looked up
            under the prefixes ``a0b0``, ``a1b0``, ``a0b1``, ``a1b1``,
            ``ha_p1`` and ``ha_p2``.

    Returns:
        ``(p0, p1, p2, p3)``: the four product bits, least significant first.
    """
    def half_add(x, y, prefix):
        # Sum bit first, then carry bit, both from the same adder prefix.
        return (
            half_adder_sum(x, y, weights, prefix),
            half_adder_carry(x, y, weights, prefix),
        )

    # Four partial products: every bit of a ANDed with every bit of b.
    low_low = and_gate(a0, b0, weights, 'a0b0')
    high_low = and_gate(a1, b0, weights, 'a1b0')
    low_high = and_gate(a0, b1, weights, 'a0b1')
    high_high = and_gate(a1, b1, weights, 'a1b1')

    # Bit 0 of the product is the lowest partial product, unchanged.
    p0 = low_low

    # Bit 1 adds the two cross terms; its carry feeds the next column.
    p1, carry = half_add(high_low, low_high, 'ha_p1')

    # Bit 2 adds the top partial product to that carry; the overflow is bit 3.
    p2, p3 = half_add(high_high, carry, 'ha_p2')

    return p0, p1, p2, p3

if __name__ == '__main__':
    # Demo: load the stored gate parameters and print the full 4x4 product table.
    w = load_model()
    print('2x2 Binary Multiplier')
    print('a × b tests:')
    for a in range(4):
        # For 0 <= a < 4, divmod by 2 yields (high bit, low bit).
        a1, a0 = divmod(a, 2)
        for b in range(4):
            b1, b0 = divmod(b, 2)
            product_bits = multiply_2x2(a0, a1, b0, b1, w)
            # Bits arrive least significant first, so position == shift amount.
            result = sum(bit << shift for shift, bit in enumerate(product_bits))
            print(f'{a} × {b} = {result}')