CharlesCNorton committed
Commit 329d4e4 · Parent(s): 41bc964

Rename OPTIMALITY_INDEX to MAGNITUDE_INDEX, add CIRCUITS_TODO, add architecture search

- Rename OPTIMALITY_INDEX.md to MAGNITUDE_INDEX.md
- Change wording from "optimal" to "minimum magnitude found"
- Add CIRCUITS_TODO.md with 46 circuits to build
- Add prune_architecture method for flat 2-layer architecture search

Files changed:
- CIRCUITS_TODO.md (+63 -0)
- OPTIMALITY_INDEX.md → MAGNITUDE_INDEX.md (+15 -13)
- prune.py (+241 -3)
CIRCUITS_TODO.md
ADDED
@@ -0,0 +1,63 @@
+# Circuits TODO
+
+Threshold logic circuits to build.
+
+## Voting / Threshold Functions
+1. 1outof4
+2. 2outof4
+3. 3outof4
+4. atmost1outof4
+5. atmost2outof4
+6. atmost3outof4
+7. exactly1outof4
+8. exactly2outof4
+9. exactly3outof4
+10. majority3
+11. majority5
+12. majority7
+13. minority3
+14. minority5
+15. minority7
+
+## Comparison
+16. lessthanorequal
+17. greaterthanorequal
+18. comparator4bit
+
+## Encoders / Decoders
+19. 2to4decoder
+20. 4to2encoder
+21. 8to3encoder
+22. gray2binary
+23. binary2gray
+24. 7segment
+
+## Arithmetic
+25. carrylookahead4bit
+26. multiplier3x3
+27. multiplier4x4
+28. incrementer4bit
+29. decrementer4bit
+30. subtractor4bit
+31. negator4bit
+
+## Bit Manipulation
+32. popcount4
+33. popcount8
+34. clz4
+35. clz8
+36. ffs4
+37. reverse4
+38. reverse8
+
+## Shift / Rotate
+39. shiftleft4
+40. shiftright4
+41. rotateleft4
+42. rotateright4
+43. barrelshift4
+
+## Multiplexers / Demultiplexers
+44. mux8
+45. demux4
+46. demux8
OPTIMALITY_INDEX.md → MAGNITUDE_INDEX.md
RENAMED
@@ -1,15 +1,15 @@
-# …
+# Magnitude Index
 
 Results of exhaustive magnitude enumeration on threshold logic circuits.
 
 ## Summary
 
-All circuits listed below have been verified via exhaustive enumeration. "Optima…
+All circuits listed below have been tested via exhaustive enumeration. "Min Mag" is the minimum magnitude at which valid configurations were found.
 
 ## Single-Layer Gates (Linearly Separable)
 
-| Circuit | Inputs | Params | …
-| …
+| Circuit | Inputs | Params | Min Mag | Solutions | Configs Tested |
+|---------|--------|--------|---------|-----------|----------------|
 | threshold-not | 1 | 2 | 1 | 1 | 5 |
 | threshold-nor | 2 | 3 | 2 | 1 | 7 |
 | threshold-implies | 2 | 3 | 2 | 1 | 25 |
@@ -25,44 +25,46 @@ All circuits listed below have been verified via exhaustive enumeration. "Optima
 
 ## Multi-Layer Gates (Not Linearly Separable)
 
-| Circuit | Inputs | Params | Original Mag | …
-| …
+| Circuit | Inputs | Params | Original Mag | Min Mag | Solutions | Reduction |
+|---------|--------|--------|--------------|---------|-----------|-----------|
 | threshold-xor | 2 | 9 | 10 | 7 | 6 | 30% |
 | threshold-xnor | 2 | 9 | 9 | 7 | 2 | 22% |
 | threshold-mux | 3 | 11 | 10 | 7 | 4 | 30% |
+| threshold-xor3 | 3 | 16 | 14 | 10 | 18 | 29% |
 
-## …
+## Magnitude-Minimized Variants
 
-These repos contain the magnitude…
+These repos contain the minimum-magnitude weights found:
 
 - `threshold-xor-mag7` - 6 solutions at magnitude 7
 - `threshold-xnor-mag7` - 2 solutions at magnitude 7
 - `threshold-mux-mag7` - 4 solutions at magnitude 7
+- `threshold-xor3-mag10` - 18 solutions at magnitude 10 (flat architecture)
 
 ## Pending / In Progress
 
 | Circuit | Params | Status |
 |---------|--------|--------|
-| threshold-halfadder | 12 | Running (expected…
+| threshold-halfadder | 12 | Running (expected min: 11) |
 | threshold-mod4 | 9 | Running |
 | threshold-biimplies | 9 | Not yet tested (same as XNOR) |
 | threshold-halfsubtractor | 12 | Not yet tested |
 
 ## Methodology
 
-Exhaustive search enumerates all integer weight configurations by magnitude level (0, 1, 2, ...) until valid solutions are found. This…
+Exhaustive search enumerates all integer weight configurations by magnitude level (0, 1, 2, ...) until valid solutions are found. This finds the minimum magnitude within the search space.
 
 For circuits with >12 parameters, exhaustive search becomes impractical. Use evolutionary or simulated annealing instead.
 
 ## Key Findings
 
-1. **Single-layer threshold gates have unique…
+1. **Single-layer threshold gates appear to have unique minimum-magnitude representations.** All tested single-layer gates have exactly 1 solution at their minimum magnitude.
 
-2. **Multi-layer gates can have solution families.** XOR has 6 solutions at magnitude 7,…
+2. **Multi-layer gates can have solution families.** XOR has 6 solutions at magnitude 7, XOR3 has 18 solutions at magnitude 10.
 
 3. **Non-linearly-separable functions benefit most from optimization.** XOR/XNOR/MUX achieved 22-30% magnitude reduction.
 
-4. **…
+4. **Architecture matters.** XOR3 flat architecture (mag 10) beats cascade architecture (mag 14) by 29%.
 
 ## Last Updated
 
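The level-by-level enumeration described in the Methodology section above can be sketched standalone. This is a minimal illustration, not the repo's batched prune.py implementation; the names `threshold` and `min_magnitude` are hypothetical. It recovers the index's first key finding on a small case: AND has a unique minimum-magnitude solution.

```python
from itertools import product

def threshold(weights, bias, inputs):
    # Threshold gate: fires when the weighted sum plus bias is >= 0
    return int(sum(w * x for w, x in zip(weights, inputs)) + bias >= 0)

def min_magnitude(truth_table, n_inputs, max_mag=10):
    """Search magnitude levels 0, 1, 2, ... and return the first level at
    which valid integer (weights, bias) configs exist, with all solutions."""
    n_params = n_inputs + 1
    for mag in range(max_mag + 1):
        solutions = []
        for cfg in product(range(-mag, mag + 1), repeat=n_params):
            if sum(abs(v) for v in cfg) != mag:
                continue  # only configs exactly at this magnitude level
            weights, bias = cfg[:-1], cfg[-1]
            if all(threshold(weights, bias, xs) == y for xs, y in truth_table):
                solutions.append(cfg)
        if solutions:
            return mag, solutions
    return None, []

AND = [((0, 0), 0), ((0, 1), 0), ((1, 0), 0), ((1, 1), 1)]
mag, sols = min_magnitude(AND, 2)
# AND is linearly separable: a single gate w=[1, 1], b=-2 works at magnitude 4,
# and no lower-magnitude integer config satisfies all four rows.
```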
prune.py
CHANGED
@@ -163,6 +163,11 @@ class Config:
     topology_remove_prob: float = 0.2
     topology_add_prob: float = 0.1
 
+    run_architecture: bool = False
+    arch_hidden_neurons: int = 3
+    arch_max_weight: int = 3
+    arch_max_mag: int = 20
+
     sensitivity_samples: int = 1000
 
     depth_max_collapse: int = 3
@@ -2921,6 +2926,231 @@ def prune_exhaustive_sparse(circuit: AdaptiveCircuit, evaluator: BatchedEvaluato
     )
 
 
+def prune_architecture(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
+    """
+    Architecture search - find optimal flat 2-layer architecture.
+
+    Searches for a flat architecture with N hidden neurons that computes
+    the same function as the circuit, potentially at lower magnitude.
+
+    Parameters controlled by:
+        cfg.arch_hidden_neurons: number of hidden neurons (default 3)
+        cfg.arch_max_weight: max absolute weight value (default 3)
+        cfg.arch_max_mag: max magnitude to search (default 20)
+    """
+    start = time.perf_counter()
+    original = circuit.stats()
+
+    n_hidden = cfg.arch_hidden_neurons
+    n_inputs = circuit.spec.inputs
+    n_outputs = circuit.spec.outputs
+    max_weight = cfg.arch_max_weight
+    max_mag = cfg.arch_max_mag
+
+    # Parameters: n_hidden * (n_inputs + 1) + n_outputs * (n_hidden + 1)
+    n_params = n_hidden * (n_inputs + 1) + n_outputs * (n_hidden + 1)
+
+    if cfg.verbose:
+        print(f"  [ARCH] Architecture search")
+        print(f"  [ARCH] Hidden neurons: {n_hidden}")
+        print(f"  [ARCH] Inputs: {n_inputs}, Outputs: {n_outputs}")
+        print(f"  [ARCH] Parameters: {n_params}")
+        print(f"  [ARCH] Max weight: {max_weight}, Max magnitude: {max_mag}")
+        print(f"  [ARCH] Searching by magnitude level...")
+
+    test_inputs = circuit.test_inputs
+    test_expected = circuit.test_expected
+
+    def eval_flat_architecture(configs: torch.Tensor) -> torch.Tensor:
+        """Evaluate batch of flat architecture configs."""
+        # Extract weights for hidden layer
+        idx = 0
+        hidden_weights = []
+        hidden_biases = []
+        for h in range(n_hidden):
+            w = configs[:, idx:idx+n_inputs]
+            idx += n_inputs
+            b = configs[:, idx:idx+1]
+            idx += 1
+            hidden_weights.append(w)
+            hidden_biases.append(b)
+
+        # Extract weights for output layer
+        output_weights = []
+        output_biases = []
+        for o in range(n_outputs):
+            w = configs[:, idx:idx+n_hidden]
+            idx += n_hidden
+            b = configs[:, idx:idx+1]
+            idx += 1
+            output_weights.append(w)
+            output_biases.append(b)
+
+        # Compute hidden activations for all test inputs
+        # test_inputs: [n_cases, n_inputs]
+        # hidden_weights[h]: [batch, n_inputs]
+        hidden_acts = []
+        for h in range(n_hidden):
+            # [batch, 1, n_inputs] * [1, n_cases, n_inputs] -> sum -> [batch, n_cases]
+            act = (hidden_weights[h].unsqueeze(1) * test_inputs.unsqueeze(0)).sum(dim=2) + hidden_biases[h]
+            act = (act >= 0).float()
+            hidden_acts.append(act)
+
+        hidden_stack = torch.stack(hidden_acts, dim=2)  # [batch, n_cases, n_hidden]
+
+        # Compute output
+        outputs = []
+        for o in range(n_outputs):
+            out = (hidden_stack * output_weights[o].unsqueeze(1)).sum(dim=2) + output_biases[o]
+            out = (out >= 0).float()
+            outputs.append(out)
+
+        if n_outputs == 1:
+            predicted = outputs[0]
+            expected = test_expected.squeeze()
+        else:
+            predicted = torch.stack(outputs, dim=2)
+            expected = test_expected
+
+        correct = (predicted == expected.unsqueeze(0)).float().mean(dim=1)
+        if n_outputs > 1:
+            correct = correct.mean(dim=1)
+
+        return correct
+
+    # Partition-based enumeration
+    @lru_cache(maxsize=None)
+    def partitions(total: int, n_slots: int, max_val: int) -> list:
+        if n_slots == 0:
+            return [()] if total == 0 else []
+        if n_slots == 1:
+            return [(total,)] if total <= max_val else []
+        result = []
+        for v in range(min(total, max_val) + 1):
+            for rest in partitions(total - v, n_slots - 1, max_val):
+                result.append((v,) + rest)
+        return result
+
+    def signs_for_partition(partition: tuple) -> torch.Tensor:
+        n = len(partition)
+        nonzero_idx = [i for i, v in enumerate(partition) if v != 0]
+        k = len(nonzero_idx)
+
+        if k == 0:
+            return torch.zeros(1, n, device=cfg.device, dtype=torch.float32)
+
+        n_patterns = 2 ** k
+        configs = torch.zeros(n_patterns, n, device=cfg.device, dtype=torch.float32)
+
+        for i, idx in enumerate(nonzero_idx):
+            signs = ((torch.arange(n_patterns, device=cfg.device) >> i) & 1) * 2 - 1
+            configs[:, idx] = signs.float() * partition[idx]
+
+        return configs
+
+    def generate_at_magnitude(target_mag: int):
+        all_configs = []
+        for partition in partitions(target_mag, n_params, max_weight):
+            signed = signs_for_partition(partition)
+            all_configs.append(signed)
+        if all_configs:
+            return torch.cat(all_configs, dim=0)
+        return torch.zeros(0, n_params, device=cfg.device)
+
+    total_tested = 0
+    all_solutions = []
+    optimal_mag = None
+
+    for target_mag in range(1, max_mag + 1):
+        mag_start = time.perf_counter()
+
+        configs = generate_at_magnitude(target_mag)
+        n_configs = configs.shape[0]
+
+        if n_configs == 0:
+            continue
+
+        if cfg.verbose:
+            print(f"    Magnitude {target_mag}: {n_configs:,} configs...", end=" ", flush=True)
+
+        # Batch evaluate
+        batch_size = 500000
+        valid_configs = []
+
+        for i in range(0, n_configs, batch_size):
+            batch = configs[i:i+batch_size]
+            fitness = eval_flat_architecture(batch)
+            valid_mask = fitness >= cfg.fitness_threshold
+            if valid_mask.any():
+                valid_configs.extend(batch[valid_mask].cpu().tolist())
+
+        total_tested += n_configs
+        mag_time = time.perf_counter() - mag_start
+
+        if valid_configs:
+            if cfg.verbose:
+                print(f"FOUND {len(valid_configs)} solutions! ({mag_time:.1f}s)")
+
+            optimal_mag = target_mag
+            all_solutions = valid_configs
+
+            if cfg.verbose:
+                print(f"  [ARCH] Optimal magnitude: {optimal_mag}")
+                print(f"  [ARCH] Solutions found: {len(all_solutions)}")
+                print(f"  [ARCH] First solution:")
+                sol = all_solutions[0]
+                idx = 0
+                for h in range(n_hidden):
+                    w = sol[idx:idx+n_inputs]
+                    idx += n_inputs
+                    b = sol[idx]
+                    idx += 1
+                    print(f"    h{h+1}: w={[int(x) for x in w]}, b={int(b)}")
+                for o in range(n_outputs):
+                    w = sol[idx:idx+n_hidden]
+                    idx += n_hidden
+                    b = sol[idx]
+                    idx += 1
+                    print(f"    out{o+1}: w={[int(x) for x in w]}, b={int(b)}")
+
+            break
+        else:
+            if cfg.verbose:
+                print(f"none ({mag_time:.1f}s)")
+
+    elapsed = time.perf_counter() - start
+
+    if cfg.verbose:
+        print(f"  [ARCH COMPLETE]")
+        print(f"  - Configurations tested: {total_tested:,}")
+        print(f"  - Optimal magnitude: {optimal_mag if optimal_mag else 'none found'}")
+        print(f"  - Original magnitude: {original['magnitude']:.0f}")
+        if optimal_mag:
+            print(f"  - Reduction: {(1 - optimal_mag/original['magnitude'])*100:.1f}%")
+        print(f"  - Solutions: {len(all_solutions)}")
+        print(f"  - Time: {elapsed:.1f}s")
+
+    return PruneResult(
+        method='architecture',
+        original_stats=original,
+        final_stats=original,  # We don't change the original weights
+        final_weights=circuit.clone_weights(),
+        fitness=evaluator.evaluate_single(circuit.weights),
+        time_seconds=elapsed,
+        metadata={
+            'hidden_neurons': n_hidden,
+            'optimal_magnitude': optimal_mag,
+            'total_tested': total_tested,
+            'solutions_count': len(all_solutions),
+            'all_solutions': all_solutions[:100]
+        }
+    )
+
+
 def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneResult]:
     """Run all enabled pruning methods."""
 
@@ -2965,6 +3195,7 @@ def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneRes
         ('fanin', cfg.run_fanin, lambda: prune_fanin(circuit, evaluator, cfg)),
         ('exhaustive_mag', cfg.run_exhaustive_mag, lambda: prune_exhaustive_mag(circuit, evaluator, cfg)),
         ('exhaustive_sparse', cfg.run_exhaustive_sparse, lambda: prune_exhaustive_sparse(circuit, evaluator, cfg)),
+        ('architecture', cfg.run_architecture, lambda: prune_architecture(circuit, evaluator, cfg)),
         ('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
        ('annealing', cfg.run_annealing, lambda: prune_annealing(circuit, evaluator, cfg)),
     ]
@@ -3140,6 +3371,9 @@ def main():
     parser.add_argument('--fanin-target', type=int, default=4)
     parser.add_argument('--sparse-max-weight', type=int, default=3, help='Max weight magnitude for sparse search')
     parser.add_argument('--exhaustive-max-params', type=int, default=12, help='Max params for exhaustive search')
+    parser.add_argument('--arch-hidden', type=int, default=3, help='Number of hidden neurons for architecture search')
+    parser.add_argument('--arch-max-weight', type=int, default=3, help='Max weight for architecture search')
+    parser.add_argument('--arch-max-mag', type=int, default=20, help='Max magnitude to search for architecture')
 
     args = parser.parse_args()
 
@@ -3163,14 +3397,17 @@ def main():
         annealing_parallel_chains=args.sa_chains,
         fanin_target=args.fanin_target,
         sparse_max_weight=args.sparse_max_weight,
-        exhaustive_max_params=args.exhaustive_max_params
+        exhaustive_max_params=args.exhaustive_max_params,
+        arch_hidden_neurons=args.arch_hidden,
+        arch_max_weight=args.arch_max_weight,
+        arch_max_mag=args.arch_max_mag
     )
 
     if args.methods:
         all_methods = ['magnitude', 'zero', 'quantize', 'evolutionary', 'annealing',
                        'structural', 'topology', 'sensitivity', 'weight_sharing',
                        'depth', 'gate_subst', 'symmetry', 'fanin',
-                       'exhaustive_mag', 'exhaustive_sparse']
+                       'exhaustive_mag', 'exhaustive_sparse', 'architecture']
        for m in all_methods:
            setattr(cfg, f'run_{m}', False)
 
@@ -3191,7 +3428,8 @@ def main():
            'sym': 'symmetry', 'symmetry': 'symmetry',
            'fanin': 'fanin', 'fan': 'fanin',
            'exhaustive_mag': 'exhaustive_mag', 'exh_mag': 'exhaustive_mag', 'exh': 'exhaustive_mag', 'brute': 'exhaustive_mag',
-           'exhaustive_sparse': 'exhaustive_sparse', 'exh_sparse': 'exhaustive_sparse', 'sparse': 'exhaustive_sparse'
+           'exhaustive_sparse': 'exhaustive_sparse', 'exh_sparse': 'exhaustive_sparse', 'sparse': 'exhaustive_sparse',
+           'architecture': 'architecture', 'arch': 'architecture'
        }
        if m in method_map:
            setattr(cfg, f'run_{method_map[m]}', True)