# File size: 3,034 Bytes
# 156dd54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""

Threshold Network for Parity Computation

A ternary threshold network that computes the parity (XOR) of 8 binary inputs.
Weights are constrained to {-1, 0, +1} and activations use the Heaviside step function.

"""

import json
import torch
import torch.nn as nn


class ThresholdNetwork(nn.Module):
    """Feed-forward threshold network with two hidden layers.

    Each of the three layers is an affine map followed by a Heaviside step
    (pre-activation >= 0 -> 1.0, otherwise 0.0).

    Default architecture: 8 -> 32 -> 16 -> 1.

    The network is intended to hold ternary weights in {-1, 0, +1}; that
    constraint is a property of the loaded checkpoint and is not enforced
    by this class.
    """

    # (checkpoint key, parameter attribute) pairs, in layer order.
    _TENSOR_KEYS = (
        ('layer1.weight', 'layer1_weight'),
        ('layer1.bias', 'layer1_bias'),
        ('layer2.weight', 'layer2_weight'),
        ('layer2.bias', 'layer2_bias'),
        ('output.weight', 'output_weight'),
        ('output.bias', 'output_bias'),
    )

    def __init__(self, n_bits=8, hidden1=32, hidden2=16):
        """Create a network whose weights and biases are all zero.

        Args:
            n_bits: Number of input features.
            hidden1: Width of the first hidden layer.
            hidden2: Width of the second hidden layer.
        """
        super().__init__()
        self.n_bits = n_bits
        self.hidden1 = hidden1
        self.hidden2 = hidden2

        # Weights use the (out_features, in_features) layout expected by
        # torch.nn.functional.linear.
        self.layer1_weight = nn.Parameter(torch.zeros(hidden1, n_bits))
        self.layer1_bias = nn.Parameter(torch.zeros(hidden1))
        self.layer2_weight = nn.Parameter(torch.zeros(hidden2, hidden1))
        self.layer2_bias = nn.Parameter(torch.zeros(hidden2))
        self.output_weight = nn.Parameter(torch.zeros(1, hidden2))
        self.output_bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Run the three threshold layers on ``x``.

        Args:
            x: Tensor whose last dimension has ``n_bits`` entries; it is
                cast to float32 before the first layer.

        Returns:
            Float tensor of 0.0/1.0 values with the trailing singleton
            output dimension squeezed away.
        """
        linear = torch.nn.functional.linear
        x = x.float()
        x = (linear(x, self.layer1_weight, self.layer1_bias) >= 0).float()
        x = (linear(x, self.layer2_weight, self.layer2_bias) >= 0).float()
        x = (linear(x, self.output_weight, self.output_bias) >= 0).float()
        return x.squeeze(-1)

    @classmethod
    def from_tensors(cls, weights):
        """Build a model from a mapping of checkpoint keys to tensors.

        Layer sizes are inferred from the shapes of ``'layer1.weight'`` and
        ``'layer2.weight'``. Tensors are copied into the float32 parameters,
        so checkpoints stored in another dtype (for example int8) still
        produce a model that ``forward`` can evaluate.

        Args:
            weights: Mapping with the keys ``layer1.weight``,
                ``layer1.bias``, ``layer2.weight``, ``layer2.bias``,
                ``output.weight`` and ``output.bias``.

        Returns:
            A new instance holding copies of the given tensors.

        Raises:
            KeyError: If any required key is absent.
            ValueError: If a tensor's shape is inconsistent with the
                inferred architecture.
        """
        missing = [key for key, _ in cls._TENSOR_KEYS if key not in weights]
        if missing:
            raise KeyError(f"checkpoint is missing tensors: {missing}")

        hidden1 = weights['layer1.weight'].shape[0]
        hidden2 = weights['layer2.weight'].shape[0]
        n_bits = weights['layer1.weight'].shape[1]

        model = cls(n_bits=n_bits, hidden1=hidden1, hidden2=hidden2)
        with torch.no_grad():
            for key, attr in cls._TENSOR_KEYS:
                param = getattr(model, attr)
                tensor = weights[key]
                # Check explicitly: copy_ would otherwise silently broadcast
                # a compatible-but-wrong shape.
                if tensor.shape != param.shape:
                    raise ValueError(
                        f"tensor '{key}' has shape {tuple(tensor.shape)}, "
                        f"expected {tuple(param.shape)}"
                    )
                # copy_ converts to the parameter's dtype (float32).
                param.copy_(tensor)

        return model

    @classmethod
    def from_safetensors(cls, path):
        """Load model from SafeTensors file.

        Args:
            path: Location of the ``.safetensors`` checkpoint.

        Returns:
            A new instance populated from the file's tensors.

        Raises:
            KeyError: If the file lacks a required tensor.
            ValueError: If the stored tensor shapes are inconsistent.
        """
        # Imported lazily so the class is usable without safetensors installed.
        from safetensors.torch import load_file

        return cls.from_tensors(load_file(path))


def parity(x):
    """Reference parity: 1.0 where the last-dimension sum is odd, else 0.0."""
    ones_per_row = torch.sum(x, dim=-1)
    odd_flag = torch.remainder(ones_per_row, 2)
    return odd_flag.float()


if __name__ == '__main__':
    # Smoke test: load the checkpoint from the working directory and compare
    # the network's output with the reference parity on a few inputs.
    model = ThresholdNetwork.from_safetensors('model.safetensors')

    test_inputs = torch.tensor([
        [0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1],
    ], dtype=torch.float32)

    outputs = model(test_inputs)
    expected = parity(test_inputs)

    print("Input -> Output (Expected)")
    # Convert each tensor to plain ints once, then walk the rows in lockstep
    # instead of indexing and calling .item() per element.
    rows = zip(
        test_inputs.int().tolist(),
        outputs.int().tolist(),
        expected.int().tolist(),
    )
    for bits, got, want in rows:
        print(f"{bits} -> {got} ({want})")