File size: 594 Bytes
fa53b57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
"""

Threshold network for the MOD-12 circuit.

For 8-bit inputs the Hamming weight (HW, the number of 1 bits) ranges from
0 to 8, which is always less than 12, so HW mod 12 = HW.

"""

import torch
from safetensors.torch import load_file


class ThresholdMod12:
    """Linear threshold unit: weighted sum of input bits plus a bias.

    The parameters come from a mapping with the keys ``'weight'`` and
    ``'bias'``. They are stored as given, with no copy or dtype cast.
    """

    def __init__(self, weights_dict):
        # The right-hand side is evaluated left to right, so a missing
        # 'weight' key is reported before a missing 'bias' key.
        self.weight, self.bias = weights_dict['weight'], weights_dict['bias']

    def __call__(self, bits):
        """Return ``sum(bits * weight) + bias`` as a tensor.

        ``bits`` is any iterable of values accepted by ``float()``, e.g.
        ints or bools. No step or threshold is applied here; the caller
        receives the pre-activation value.
        """
        # NOTE(review): the length of ``bits`` is not checked against
        # ``self.weight``. A shorter sequence (e.g. a single bit) may
        # broadcast silently instead of raising; confirm callers always
        # pass the full input width.
        x = torch.tensor(list(map(float, bits)))
        weighted = x * self.weight
        return weighted.sum() + self.bias

    @classmethod
    def from_safetensors(cls, path="model.safetensors"):
        """Build an instance from the tensors stored in a safetensors file."""
        tensors = load_file(path)
        return cls(tensors)