File size: 3,464 Bytes
278368f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""

Inference code for mod5-verified threshold network.



This network computes MOD-5 (Hamming weight mod 5) on 8-bit binary inputs.

"""

import torch
import torch.nn as nn
from safetensors.torch import load_file


def heaviside(x):
    """Apply a unit step elementwise.

    Entries that are >= 0 map to 1.0; every other entry (negatives, and NaN,
    which compares false) maps to 0.0. The result is a float32 tensor with
    the same shape as ``x``.
    """
    nonnegative = torch.ge(x, 0)
    return nonnegative.to(torch.float32)


class Mod5Network(nn.Module):
    """
    Verified threshold network for MOD-5 computation.

    Architecture: 8 -> 9 -> 4 -> 5

    - Layer 1: Thermometer encoding (9 neurons detect HW >= k)
    - Layer 2: MOD-5 detection using (1,1,1,1,-4) weight pattern
    - Output: 5-class classification

    Both hidden layers are followed by a Heaviside step; the output layer
    returns raw scores whose argmax is the predicted residue.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 9)
        self.layer2 = nn.Linear(9, 4)
        self.output = nn.Linear(4, 5)

    def forward(self, x):
        """Forward pass with Heaviside activation.

        Args:
            x: Tensor whose last dimension has size 8 (one entry per input
                bit). It is cast to float32 before the first layer.

        Returns:
            Output-layer scores with last dimension 5.
        """
        x = x.float()
        x = heaviside(self.layer1(x))
        x = heaviside(self.layer2(x))
        return self.output(x)

    def predict(self, x):
        """Get predicted class (0, 1, 2, 3, or 4).

        Runs under ``torch.no_grad()``: the step activations have no useful
        gradient, so recording an autograd graph here would only cost memory.
        """
        with torch.no_grad():
            return self.forward(x).argmax(dim=-1)

    @classmethod
    def from_safetensors(cls, path):
        """Load model from safetensors file.

        Args:
            path: Path to a safetensors file holding tensors named like this
                module's state_dict keys ('layer1.weight', 'layer1.bias', ...).

        Returns:
            A new Mod5Network with its parameters filled from the file.

        Raises:
            KeyError: if a required tensor is missing from the file.
            RuntimeError: if a stored tensor's shape does not match its layer.
        """
        return cls._from_weights(load_file(path))

    @classmethod
    def _from_weights(cls, weights):
        """Build a model from a mapping of state_dict key -> tensor."""
        model = cls()
        # Take exactly this module's own parameter names: a missing entry
        # raises KeyError, and any extra entries in ``weights`` are ignored.
        state = {name: weights[name] for name in model.state_dict()}
        # load_state_dict copies into the existing float32 parameters. A shape
        # mismatch therefore raises instead of silently resizing a layer. A
        # stored tensor of another dtype is cast rather than replacing the
        # parameter's dtype, which would break forward() on float32 input.
        model.load_state_dict(state)
        return model


def mod5_reference(x):
    """Return the Hamming weight of ``x`` reduced modulo 5.

    The sum runs over the last dimension, and the residues are returned
    as an int64 tensor.
    """
    hamming_weight = torch.sum(x, dim=-1)
    return torch.remainder(hamming_weight, 5).to(torch.int64)


def verify(model, verbose=True):
    """Compare ``model.predict`` with ``mod5_reference`` on every 8-bit input.

    Row ``i`` of the exhaustive input table holds the bits of ``i``, with the
    least significant bit in column 0.

    Args:
        model: Object exposing ``predict(inputs)`` that returns class indices.
        verbose: When true, print the accuracy and, on failure, up to the
            first ten mismatching row indices.

    Returns:
        True only if all 256 predictions match the reference.
    """
    table = torch.tensor(
        [[(code >> bit) & 1 for bit in range(8)] for code in range(256)],
        dtype=torch.get_default_dtype(),
    )

    expected = mod5_reference(table)
    predicted = model.predict(table)

    matches = predicted == expected
    correct = matches.sum().item()

    if verbose:
        print(f"Verification: {correct}/256 ({100*correct/256:.1f}%)")
        if correct < 256:
            wrong_rows = torch.nonzero(predicted != expected, as_tuple=True)[0]
            print(f"Errors at indices: {wrong_rows[:10].tolist()}")

    return correct == 256


def demo():
    """Load the MOD-5 model, verify it exhaustively, and print sample predictions.

    The weights are read from the relative path 'model.safetensors'. The
    samples are the inputs with k leading ones and 8 - k trailing zeros,
    for k = 0..8.
    """
    print("Loading mod5-verified model...")
    model = Mod5Network.from_safetensors('model.safetensors')

    print("\nVerifying on all 256 inputs...")
    verify(model)

    print("\nExample predictions:")
    samples = [[1] * ones + [0] * (8 - ones) for ones in range(9)]

    for bits in samples:
        weight = sum(bits)
        expected = weight % 5
        pred = model.predict(torch.tensor([bits], dtype=torch.float32)).item()
        verdict = "ERROR" if pred != expected else "OK"
        print(f"  {bits} -> HW={weight}, pred={pred}, expected={expected} [{verdict}]")


if __name__ == '__main__':
    # Run the demo only when executed as a script, not on import.
    demo()