File size: 2,840 Bytes
8ec63f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""

Pruned Threshold Network for Parity Computation

A minimal ternary threshold network (8->11->3->1) that computes 8-bit parity.
Pruned from the original 8->32->16->1 architecture, an 83.3% reduction in
parameter count.

"""

import torch
import torch.nn as nn


class PrunedThresholdNetwork(nn.Module):
    """
    Three-layer threshold network with hard (Heaviside) activations.

    Default architecture: 8 -> 11 -> 3 -> 1

    Weights: intended to hold ternary values {-1, 0, +1}; this class does not
    enforce that, it simply stores whatever values are loaded.

    Activation: Heaviside (x >= 0 -> 1, else 0), applied after every layer,
    including the output layer.

    All parameters are stored as float32 and start at zero; use
    ``from_safetensors`` to obtain a model with real weights.
    """

    def __init__(self, n_bits=8, hidden1=11, hidden2=3):
        """
        Allocate zero-initialised parameters for the three affine layers.

        Args:
            n_bits: Number of input features (bits per sample).
            hidden1: Width of the first hidden layer.
            hidden2: Width of the second hidden layer.
        """
        super().__init__()
        self.n_bits = n_bits
        self.hidden1 = hidden1
        self.hidden2 = hidden2

        self.layer1_weight = nn.Parameter(torch.zeros(hidden1, n_bits))
        self.layer1_bias = nn.Parameter(torch.zeros(hidden1))
        self.layer2_weight = nn.Parameter(torch.zeros(hidden2, hidden1))
        self.layer2_bias = nn.Parameter(torch.zeros(hidden2))
        self.output_weight = nn.Parameter(torch.zeros(1, hidden2))
        self.output_bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """
        Forward pass with Heaviside activation.

        Args:
            x: Tensor whose last dimension has ``n_bits`` entries. It is cast
                to float32 before the first layer.

        Returns:
            Float tensor of 0.0/1.0 values with the trailing size-1 output
            dimension squeezed away.
        """
        x = x.float()
        layers = (
            (self.layer1_weight, self.layer1_bias),
            (self.layer2_weight, self.layer2_bias),
            (self.output_weight, self.output_bias),
        )
        for weight, bias in layers:
            # Hard threshold: a pre-activation of exactly 0 fires (maps to 1).
            x = (torch.nn.functional.linear(x, weight, bias) >= 0).float()
        return x.squeeze(-1)

    def _load_tensors(self, weights):
        """
        Copy tensors from a name -> tensor mapping into this model's parameters.

        Values are copied into the existing float32 parameters rather than
        rebinding ``.data``, so tensors stored in another dtype (e.g. int8 or
        float16) are cast on load and ``forward`` keeps working.

        Args:
            weights: Mapping with the keys ``layer1.weight``, ``layer1.bias``,
                ``layer2.weight``, ``layer2.bias``, ``output.weight`` and
                ``output.bias``.

        Raises:
            KeyError: If one of the expected keys is missing.
            ValueError: If a tensor's shape differs from the parameter's shape.
        """
        targets = {
            'layer1.weight': self.layer1_weight,
            'layer1.bias': self.layer1_bias,
            'layer2.weight': self.layer2_weight,
            'layer2.bias': self.layer2_bias,
            'output.weight': self.output_weight,
            'output.bias': self.output_bias,
        }
        with torch.no_grad():
            for key, param in targets.items():
                tensor = weights[key]
                # copy_ would silently broadcast, so check the shape ourselves.
                if tensor.shape != param.shape:
                    raise ValueError(
                        f"Shape mismatch for '{key}': expected "
                        f"{tuple(param.shape)}, got {tuple(tensor.shape)}"
                    )
                param.copy_(tensor)

    @classmethod
    def from_safetensors(cls, path):
        """
        Load model from SafeTensors file.

        Layer sizes are inferred from the shapes of ``layer1.weight`` and
        ``layer2.weight`` in the file. Requires the ``safetensors`` package,
        which is imported lazily here.

        Args:
            path: Location of the ``.safetensors`` file.

        Returns:
            A new model whose parameters hold the file's values as float32.

        Raises:
            KeyError: If an expected tensor name is missing from the file.
            ValueError: If a stored tensor does not match the inferred
                architecture.
        """
        from safetensors.torch import load_file

        weights = load_file(path)

        hidden1 = weights['layer1.weight'].shape[0]
        hidden2 = weights['layer2.weight'].shape[0]
        n_bits = weights['layer1.weight'].shape[1]

        model = cls(n_bits=n_bits, hidden1=hidden1, hidden2=hidden2)
        model._load_tensors(weights)

        return model


def parity(x):
    """
    Reference parity: 1.0 where the last-axis sum is odd, 0.0 where it is even.

    Args:
        x: Tensor of bit values; the reduction runs over its last dimension.

    Returns:
        Float tensor with the last dimension of ``x`` removed.
    """
    bit_count = torch.sum(x, dim=-1)
    remainder = bit_count % 2
    return remainder.float()


if __name__ == '__main__':
    # Load the pruned weights from the working directory.
    model = PrunedThresholdNetwork.from_safetensors('model.safetensors')

    # Exhaustively test every input of the width the loaded model expects,
    # instead of assuming 8 bits: the width is inferred from the weights file.
    n_bits = model.n_bits
    total = 2 ** n_bits
    all_inputs = torch.tensor(
        [[int(b) for b in format(i, f'0{n_bits}b')] for i in range(total)],
        dtype=torch.float32,
    )
    outputs = model(all_inputs)
    expected = parity(all_inputs)

    # Count exact matches between the network output and the reference parity.
    correct = (outputs == expected).sum().item()
    print(f'Accuracy: {correct}/{total} ({100*correct/total:.1f}%)')