CharlesCNorton commited on
Commit
818e32c
·
1 Parent(s): 5dd640d

Fix model.py interface for pruner compatibility, add optimality note

Browse files

- Change __call__ to accept *bits (variable args) instead of single list
- Update forward() to match pruner's expected signature
- Exhaustive enumeration confirms magnitude 13 is minimum

Files changed (2) hide show
  1. README.md +4 -0
  2. model.py +6 -5
README.md CHANGED
@@ -56,6 +56,10 @@ These are not complements (they don't sum to 1). The gap at HW=4 belongs to neit
56
  | Bias | -5 |
57
  | Total | 9 parameters |
58
 
 
 
 
 
59
  ## Usage
60
 
61
  ```python
 
56
  | Bias | -5 |
57
  | Total | 9 parameters |
58
 
59
+ ## Optimality
60
+
61
+ Exhaustive enumeration of all 27,298,155 weight configurations at magnitudes 0-13 confirms this circuit is **already at minimum magnitude (13)**. There is exactly one valid configuration at magnitude 13, and no valid configurations exist below it.
62
+
63
  ## Usage
64
 
65
  ```python
model.py CHANGED
@@ -21,7 +21,7 @@ class ThresholdMajority:
21
  self.weight = weights_dict['weight']
22
  self.bias = weights_dict['bias']
23
 
24
- def __call__(self, bits):
25
  inputs = torch.tensor([float(b) for b in bits])
26
  weighted_sum = (inputs * self.weight).sum() + self.bias
27
  return (weighted_sum >= 0).float()
@@ -31,19 +31,20 @@ class ThresholdMajority:
31
  return cls(load_file(path))
32
 
33
 
34
- def forward(x, weights):
35
  """
36
  Forward pass with Heaviside activation.
37
 
38
  Args:
39
- x: Input tensor of shape [..., 8]
40
  weights: Dict with 'weight' and 'bias' tensors
41
 
42
  Returns:
43
  1 if majority (5+ of 8) are true, else 0
44
  """
45
- x = torch.as_tensor(x, dtype=torch.float32)
46
- weighted_sum = (x * weights['weight']).sum(dim=-1) + weights['bias']
 
47
  return (weighted_sum >= 0).float()
48
 
49
 
 
21
  self.weight = weights_dict['weight']
22
  self.bias = weights_dict['bias']
23
 
24
+ def __call__(self, *bits):
25
  inputs = torch.tensor([float(b) for b in bits])
26
  weighted_sum = (inputs * self.weight).sum() + self.bias
27
  return (weighted_sum >= 0).float()
 
31
  return cls(load_file(path))
32
 
33
 
34
+ def forward(x0, x1, x2, x3, x4, x5, x6, x7, weights):
35
  """
36
  Forward pass with Heaviside activation.
37
 
38
  Args:
39
+ x0-x7: Individual input bits
40
  weights: Dict with 'weight' and 'bias' tensors
41
 
42
  Returns:
43
  1 if majority (5+ of 8) are true, else 0
44
  """
45
+ x = torch.tensor([float(x0), float(x1), float(x2), float(x3),
46
+ float(x4), float(x5), float(x6), float(x7)])
47
+ weighted_sum = (x * weights['weight']).sum() + weights['bias']
48
  return (weighted_sum >= 0).float()
49
 
50