CharlesCNorton
committed on
Commit
·
43d198d
1
Parent(s):
ff1ca32
Add magnitude-ordered exhaustive search with GPU-vectorized evaluation
Browse files- Exhaustive search now searches by magnitude level (0, 1, 2, ...) and stops
at first valid magnitude, finding ALL solutions at that level
- Add matrix-style weight detection for circuits using layer1.weight [H,I] format
- Add GPU-vectorized evaluation for matrix-style circuits (massive speedup)
- Add MUX/DEMUX native forward function detection
- Report full metrics for solutions: magnitude, nonzero count, max|w|, sparsity
- Fix CIRCUITS_PATH to point to threshold-logic-circuits
prune.py
CHANGED
|
@@ -47,7 +47,7 @@ except ImportError:
|
|
| 47 |
|
| 48 |
warnings.filterwarnings('ignore')
|
| 49 |
|
| 50 |
-
CIRCUITS_PATH = Path('D:/threshold-circuits')
|
| 51 |
RESULTS_PATH = CIRCUITS_PATH / 'pruned_results'
|
| 52 |
|
| 53 |
|
|
@@ -534,6 +534,17 @@ class AdaptiveCircuit:
|
|
| 534 |
model = cls(weights)
|
| 535 |
return list(model(int(inputs[0]), int(inputs[1])))
|
| 536 |
return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
|
| 538 |
for attr_name in dir(module):
|
| 539 |
if attr_name.startswith('_'):
|
|
@@ -914,6 +925,11 @@ class BatchedEvaluator:
|
|
| 914 |
self.n_inputs = circuit.spec.inputs
|
| 915 |
self.n_outputs = circuit.spec.outputs
|
| 916 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 917 |
self.use_vmap = VMAP_AVAILABLE and cfg.device == 'cuda' and not circuit.has_native
|
| 918 |
self.vmap_forward = None
|
| 919 |
self.vmapped_fn = None
|
|
@@ -935,7 +951,87 @@ class BatchedEvaluator:
|
|
| 935 |
batched_status = "yes" if self.batched_ready and not self.use_native_eval else "no"
|
| 936 |
native_status = "forced" if self.use_native_eval else ("available" if circuit.has_native else "none")
|
| 937 |
seq_deps = "yes" if self.has_sequential_deps else "no"
|
| 938 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 939 |
|
| 940 |
def _detect_sequential_dependencies(self) -> bool:
|
| 941 |
"""Detect if circuit has neurons that depend on other neurons' outputs."""
|
|
@@ -1086,6 +1182,13 @@ class BatchedEvaluator:
|
|
| 1086 |
if pop_size > self.max_batch:
|
| 1087 |
return self._evaluate_chunked(population)
|
| 1088 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1089 |
if self.use_native_eval:
|
| 1090 |
return self._evaluate_native_parallel(population)
|
| 1091 |
|
|
@@ -2656,26 +2759,55 @@ def prune_fanin(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Conf
|
|
| 2656 |
)
|
| 2657 |
|
| 2658 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2659 |
def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 2660 |
"""
|
| 2661 |
-
Exhaustive search
|
| 2662 |
|
| 2663 |
-
|
| 2664 |
-
|
| 2665 |
-
globally optimal solutions but only feasible for small circuits.
|
| 2666 |
|
| 2667 |
-
|
| 2668 |
"""
|
| 2669 |
start = time.perf_counter()
|
| 2670 |
original = circuit.stats()
|
| 2671 |
|
| 2672 |
n_params = original['total']
|
| 2673 |
-
|
| 2674 |
|
| 2675 |
if n_params > cfg.exhaustive_max_params:
|
| 2676 |
if cfg.verbose:
|
| 2677 |
print(f" [EXHAUSTIVE] Skipping: {n_params} params exceeds max {cfg.exhaustive_max_params}")
|
| 2678 |
-
print(f" [EXHAUSTIVE] Search space would be {(2*search_range+1)**n_params:,} combinations")
|
| 2679 |
return PruneResult(
|
| 2680 |
method='exhaustive',
|
| 2681 |
original_stats=original,
|
|
@@ -2686,28 +2818,15 @@ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
|
|
| 2686 |
metadata={'skipped': True, 'reason': 'too_many_params'}
|
| 2687 |
)
|
| 2688 |
|
| 2689 |
-
search_space = (2 * search_range + 1) ** n_params
|
| 2690 |
-
|
| 2691 |
if cfg.verbose:
|
| 2692 |
print(f" [EXHAUSTIVE] Parameters: {n_params}")
|
| 2693 |
-
print(f" [EXHAUSTIVE]
|
| 2694 |
-
print(f" [EXHAUSTIVE]
|
| 2695 |
|
| 2696 |
weight_keys = list(circuit.weights.keys())
|
| 2697 |
weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
|
| 2698 |
weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}
|
| 2699 |
|
| 2700 |
-
best_weights = circuit.clone_weights()
|
| 2701 |
-
best_mag = original['magnitude']
|
| 2702 |
-
best_fitness = evaluator.evaluate_single(best_weights)
|
| 2703 |
-
|
| 2704 |
-
values = list(range(-search_range, search_range + 1))
|
| 2705 |
-
|
| 2706 |
-
tested = 0
|
| 2707 |
-
valid_found = 0
|
| 2708 |
-
report_interval = max(1, search_space // 100)
|
| 2709 |
-
last_report_time = start
|
| 2710 |
-
|
| 2711 |
def vector_to_weights(vec):
|
| 2712 |
weights = {}
|
| 2713 |
idx = 0
|
|
@@ -2717,50 +2836,92 @@ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
|
|
| 2717 |
idx += size
|
| 2718 |
return weights
|
| 2719 |
|
| 2720 |
-
|
| 2721 |
-
|
| 2722 |
-
|
| 2723 |
|
| 2724 |
-
for
|
| 2725 |
-
|
|
|
|
|
|
|
| 2726 |
|
| 2727 |
-
|
| 2728 |
-
|
| 2729 |
|
| 2730 |
-
if
|
| 2731 |
-
|
| 2732 |
-
mag = sum(abs(v) for v in combo)
|
| 2733 |
|
| 2734 |
-
|
| 2735 |
-
best_mag = mag
|
| 2736 |
-
best_weights = {k: v.clone() for k, v in weights.items()}
|
| 2737 |
-
best_fitness = fitness
|
| 2738 |
|
| 2739 |
-
|
| 2740 |
-
|
| 2741 |
-
|
| 2742 |
-
|
| 2743 |
-
|
| 2744 |
-
|
| 2745 |
-
|
| 2746 |
-
|
| 2747 |
-
|
| 2748 |
-
|
| 2749 |
-
|
| 2750 |
-
|
| 2751 |
-
|
| 2752 |
-
|
| 2753 |
-
|
| 2754 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2755 |
|
| 2756 |
if cfg.verbose:
|
| 2757 |
-
elapsed = time.perf_counter() - start
|
| 2758 |
print(f" [EXHAUSTIVE COMPLETE]")
|
| 2759 |
-
print(f" -
|
| 2760 |
-
print(f" -
|
| 2761 |
-
print(f" -
|
| 2762 |
-
print(f" - Reduction: {(1 -
|
| 2763 |
-
print(f" - Time: {elapsed:.1f}s
|
| 2764 |
|
| 2765 |
return PruneResult(
|
| 2766 |
method='exhaustive',
|
|
@@ -2768,12 +2929,12 @@ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
|
|
| 2768 |
final_stats=circuit.stats(best_weights),
|
| 2769 |
final_weights=best_weights,
|
| 2770 |
fitness=best_fitness,
|
| 2771 |
-
time_seconds=
|
| 2772 |
metadata={
|
| 2773 |
-
'
|
| 2774 |
-
'
|
| 2775 |
-
'
|
| 2776 |
-
'
|
| 2777 |
}
|
| 2778 |
)
|
| 2779 |
|
|
|
|
| 47 |
|
| 48 |
warnings.filterwarnings('ignore')
|
| 49 |
|
| 50 |
+
CIRCUITS_PATH = Path('D:/threshold-logic-circuits')
|
| 51 |
RESULTS_PATH = CIRCUITS_PATH / 'pruned_results'
|
| 52 |
|
| 53 |
|
|
|
|
| 534 |
model = cls(weights)
|
| 535 |
return list(model(int(inputs[0]), int(inputs[1])))
|
| 536 |
return wrapper
|
| 537 |
+
elif 'mux' in name and hasattr(module, 'mux'):
|
| 538 |
+
mux_fn = module.mux
|
| 539 |
+
def wrapper(inputs, weights):
|
| 540 |
+
return [mux_fn(int(inputs[0]), int(inputs[1]), int(inputs[2]), weights)]
|
| 541 |
+
return wrapper
|
| 542 |
+
elif 'demux' in name and hasattr(module, 'demux'):
|
| 543 |
+
demux_fn = module.demux
|
| 544 |
+
def wrapper(inputs, weights):
|
| 545 |
+
result = demux_fn(int(inputs[0]), int(inputs[1]), weights)
|
| 546 |
+
return list(result) if isinstance(result, (list, tuple)) else [result]
|
| 547 |
+
return wrapper
|
| 548 |
|
| 549 |
for attr_name in dir(module):
|
| 550 |
if attr_name.startswith('_'):
|
|
|
|
| 925 |
self.n_inputs = circuit.spec.inputs
|
| 926 |
self.n_outputs = circuit.spec.outputs
|
| 927 |
|
| 928 |
+
self.matrix_style = self._detect_matrix_style()
|
| 929 |
+
self.matrix_layout = None
|
| 930 |
+
if self.matrix_style:
|
| 931 |
+
self.matrix_layout = self._build_matrix_layout()
|
| 932 |
+
|
| 933 |
self.use_vmap = VMAP_AVAILABLE and cfg.device == 'cuda' and not circuit.has_native
|
| 934 |
self.vmap_forward = None
|
| 935 |
self.vmapped_fn = None
|
|
|
|
| 951 |
batched_status = "yes" if self.batched_ready and not self.use_native_eval else "no"
|
| 952 |
native_status = "forced" if self.use_native_eval else ("available" if circuit.has_native else "none")
|
| 953 |
seq_deps = "yes" if self.has_sequential_deps else "no"
|
| 954 |
+
matrix_status = "yes" if self.matrix_style else "no"
|
| 955 |
+
print(f" [EVAL] Evaluator ready: batch={self.max_batch:,}, vmap={vmap_status}, matrix={matrix_status}, batched={batched_status}, native={native_status}, seq_deps={seq_deps}")
|
| 956 |
+
|
| 957 |
+
def _detect_matrix_style(self) -> bool:
|
| 958 |
+
"""Detect if circuit uses matrix-style weights (layer1.weight with multiple rows)."""
|
| 959 |
+
weights = self.circuit.weights
|
| 960 |
+
for key, tensor in weights.items():
|
| 961 |
+
if 'layer1.weight' in key or key == 'layer1.weight':
|
| 962 |
+
if tensor.dim() == 2 and tensor.shape[0] > 1:
|
| 963 |
+
return True
|
| 964 |
+
return False
|
| 965 |
+
|
| 966 |
+
def _build_matrix_layout(self) -> dict:
|
| 967 |
+
"""Build layout info for matrix-style weight extraction.
|
| 968 |
+
MUST match the exact dict order used by circuit.weights_to_vector."""
|
| 969 |
+
weights = self.circuit.weights
|
| 970 |
+
weight_keys = list(weights.keys())
|
| 971 |
+
layout = {'layers': {}, 'total_params': 0, 'key_order': weight_keys}
|
| 972 |
+
|
| 973 |
+
idx = 0
|
| 974 |
+
for key in weight_keys:
|
| 975 |
+
tensor = weights[key]
|
| 976 |
+
size = tensor.numel()
|
| 977 |
+
|
| 978 |
+
if '.weight' in key or key.endswith('weight'):
|
| 979 |
+
layer_name = key.replace('.weight', '')
|
| 980 |
+
layout['layers'].setdefault(layer_name, {})
|
| 981 |
+
layout['layers'][layer_name]['w_start'] = idx
|
| 982 |
+
layout['layers'][layer_name]['w_end'] = idx + size
|
| 983 |
+
layout['layers'][layer_name]['w_shape'] = tuple(tensor.shape)
|
| 984 |
+
elif '.bias' in key or key.endswith('bias'):
|
| 985 |
+
layer_name = key.replace('.bias', '')
|
| 986 |
+
layout['layers'].setdefault(layer_name, {})
|
| 987 |
+
layout['layers'][layer_name]['b_start'] = idx
|
| 988 |
+
layout['layers'][layer_name]['b_end'] = idx + size
|
| 989 |
+
layout['layers'][layer_name]['b_shape'] = tuple(tensor.shape)
|
| 990 |
+
|
| 991 |
+
idx += size
|
| 992 |
+
|
| 993 |
+
layout['total_params'] = idx
|
| 994 |
+
layout['layer_order'] = sorted(layout['layers'].keys())
|
| 995 |
+
return layout
|
| 996 |
+
|
| 997 |
+
def _evaluate_matrix_vectorized(self, population: torch.Tensor) -> torch.Tensor:
|
| 998 |
+
"""
|
| 999 |
+
Fully vectorized evaluation for matrix-style weights.
|
| 1000 |
+
Handles circuits with layer1.weight [H, I], layer1.bias [H], layer2.weight [O, H], layer2.bias [O].
|
| 1001 |
+
"""
|
| 1002 |
+
pop_size = population.shape[0]
|
| 1003 |
+
device = population.device
|
| 1004 |
+
|
| 1005 |
+
layers = self.matrix_layout['layers']
|
| 1006 |
+
layer_order = self.matrix_layout['layer_order']
|
| 1007 |
+
|
| 1008 |
+
if len(layer_order) < 2:
|
| 1009 |
+
raise ValueError("Matrix-style eval requires at least 2 layers")
|
| 1010 |
+
|
| 1011 |
+
with torch.no_grad():
|
| 1012 |
+
l1_name = layer_order[0]
|
| 1013 |
+
l1 = layers[l1_name]
|
| 1014 |
+
l1_w = population[:, l1['w_start']:l1['w_end']].view(pop_size, *l1['w_shape'])
|
| 1015 |
+
l1_b = population[:, l1['b_start']:l1['b_end']].view(pop_size, *l1['b_shape'])
|
| 1016 |
+
|
| 1017 |
+
l2_name = layer_order[1]
|
| 1018 |
+
l2 = layers[l2_name]
|
| 1019 |
+
l2_w = population[:, l2['w_start']:l2['w_end']].view(pop_size, *l2['w_shape'])
|
| 1020 |
+
l2_b = population[:, l2['b_start']:l2['b_end']].view(pop_size, *l2['b_shape'])
|
| 1021 |
+
|
| 1022 |
+
inp = self.test_inputs
|
| 1023 |
+
|
| 1024 |
+
hidden = torch.einsum('ti,bhi->bth', inp, l1_w) + l1_b.unsqueeze(1)
|
| 1025 |
+
hidden = (hidden >= 0).float()
|
| 1026 |
+
|
| 1027 |
+
output = torch.einsum('bth,boh->bto', hidden, l2_w) + l2_b.unsqueeze(1)
|
| 1028 |
+
output = (output >= 0).float()
|
| 1029 |
+
|
| 1030 |
+
expected = self.test_expected.unsqueeze(0).expand(pop_size, -1, -1)
|
| 1031 |
+
correct = (output == expected).all(dim=-1).float().sum(dim=-1)
|
| 1032 |
+
fitness = correct / self.n_cases
|
| 1033 |
+
|
| 1034 |
+
return fitness
|
| 1035 |
|
| 1036 |
def _detect_sequential_dependencies(self) -> bool:
|
| 1037 |
"""Detect if circuit has neurons that depend on other neurons' outputs."""
|
|
|
|
| 1182 |
if pop_size > self.max_batch:
|
| 1183 |
return self._evaluate_chunked(population)
|
| 1184 |
|
| 1185 |
+
if self.matrix_style and self.matrix_layout:
|
| 1186 |
+
try:
|
| 1187 |
+
return self._evaluate_matrix_vectorized(population)
|
| 1188 |
+
except Exception as e:
|
| 1189 |
+
if self.cfg.verbose:
|
| 1190 |
+
print(f" Matrix vectorized eval failed ({e}), trying other methods")
|
| 1191 |
+
|
| 1192 |
if self.use_native_eval:
|
| 1193 |
return self._evaluate_native_parallel(population)
|
| 1194 |
|
|
|
|
| 2759 |
)
|
| 2760 |
|
| 2761 |
|
| 2762 |
+
def _partitions(total: int, n: int, max_val: int):
|
| 2763 |
+
"""Generate all ways to partition 'total' into 'n' non-negative integers <= max_val."""
|
| 2764 |
+
if n == 0:
|
| 2765 |
+
if total == 0:
|
| 2766 |
+
yield []
|
| 2767 |
+
return
|
| 2768 |
+
for i in range(min(total, max_val) + 1):
|
| 2769 |
+
for rest in _partitions(total - i, n - 1, max_val):
|
| 2770 |
+
yield [i] + rest
|
| 2771 |
+
|
| 2772 |
+
|
| 2773 |
+
def _all_signs(abs_vals: list):
|
| 2774 |
+
"""Generate all sign combinations for absolute values."""
|
| 2775 |
+
if not abs_vals:
|
| 2776 |
+
yield []
|
| 2777 |
+
return
|
| 2778 |
+
for rest in _all_signs(abs_vals[1:]):
|
| 2779 |
+
if abs_vals[0] == 0:
|
| 2780 |
+
yield [0] + rest
|
| 2781 |
+
else:
|
| 2782 |
+
yield [abs_vals[0]] + rest
|
| 2783 |
+
yield [-abs_vals[0]] + rest
|
| 2784 |
+
|
| 2785 |
+
|
| 2786 |
+
def _configs_at_magnitude(mag: int, n_params: int):
    """Generate all n_params-length configs with given total magnitude."""
    # Enumerate absolute-value partitions, then fan each out over all signs.
    yield from (
        tuple(signed)
        for abs_vals in _partitions(mag, n_params, mag)
        for signed in _all_signs(abs_vals)
    )
|
| 2791 |
+
|
| 2792 |
+
|
| 2793 |
def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 2794 |
"""
|
| 2795 |
+
Exhaustive search by magnitude level - finds provably optimal solutions.
|
| 2796 |
|
| 2797 |
+
Searches magnitude 0, then 1, then 2, ... until valid solutions found.
|
| 2798 |
+
Returns ALL valid solutions at the minimum magnitude (to discover families).
|
|
|
|
| 2799 |
|
| 2800 |
+
Much faster than arbitrary-order search since it stops at first valid magnitude.
|
| 2801 |
"""
|
| 2802 |
start = time.perf_counter()
|
| 2803 |
original = circuit.stats()
|
| 2804 |
|
| 2805 |
n_params = original['total']
|
| 2806 |
+
max_mag = int(original['magnitude'])
|
| 2807 |
|
| 2808 |
if n_params > cfg.exhaustive_max_params:
|
| 2809 |
if cfg.verbose:
|
| 2810 |
print(f" [EXHAUSTIVE] Skipping: {n_params} params exceeds max {cfg.exhaustive_max_params}")
|
|
|
|
| 2811 |
return PruneResult(
|
| 2812 |
method='exhaustive',
|
| 2813 |
original_stats=original,
|
|
|
|
| 2818 |
metadata={'skipped': True, 'reason': 'too_many_params'}
|
| 2819 |
)
|
| 2820 |
|
|
|
|
|
|
|
| 2821 |
if cfg.verbose:
|
| 2822 |
print(f" [EXHAUSTIVE] Parameters: {n_params}")
|
| 2823 |
+
print(f" [EXHAUSTIVE] Original magnitude: {max_mag}")
|
| 2824 |
+
print(f" [EXHAUSTIVE] Searching by magnitude level (0, 1, 2, ...)")
|
| 2825 |
|
| 2826 |
weight_keys = list(circuit.weights.keys())
|
| 2827 |
weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
|
| 2828 |
weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}
|
| 2829 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2830 |
def vector_to_weights(vec):
|
| 2831 |
weights = {}
|
| 2832 |
idx = 0
|
|
|
|
| 2836 |
idx += size
|
| 2837 |
return weights
|
| 2838 |
|
| 2839 |
+
total_tested = 0
|
| 2840 |
+
all_solutions = []
|
| 2841 |
+
optimal_mag = None
|
| 2842 |
|
| 2843 |
+
for mag in range(0, max_mag + 1):
|
| 2844 |
+
mag_start = time.perf_counter()
|
| 2845 |
+
configs = list(_configs_at_magnitude(mag, n_params))
|
| 2846 |
+
n_configs = len(configs)
|
| 2847 |
|
| 2848 |
+
if n_configs == 0:
|
| 2849 |
+
continue
|
| 2850 |
|
| 2851 |
+
if cfg.verbose:
|
| 2852 |
+
print(f" Magnitude {mag}: {n_configs:,} configurations...", end=" ", flush=True)
|
|
|
|
| 2853 |
|
| 2854 |
+
valid_at_mag = []
|
|
|
|
|
|
|
|
|
|
| 2855 |
|
| 2856 |
+
batch_size = min(100000, n_configs)
|
| 2857 |
+
for batch_start in range(0, n_configs, batch_size):
|
| 2858 |
+
batch_end = min(batch_start + batch_size, n_configs)
|
| 2859 |
+
batch_configs = configs[batch_start:batch_end]
|
| 2860 |
+
|
| 2861 |
+
population = torch.tensor(batch_configs, dtype=torch.float32, device=cfg.device)
|
| 2862 |
+
|
| 2863 |
+
try:
|
| 2864 |
+
fitness_batch = evaluator.evaluate_population(population)
|
| 2865 |
+
except:
|
| 2866 |
+
fitness_batch = torch.tensor([
|
| 2867 |
+
evaluator.evaluate_single(vector_to_weights(c))
|
| 2868 |
+
for c in batch_configs
|
| 2869 |
+
], device=cfg.device)
|
| 2870 |
+
|
| 2871 |
+
valid_mask = fitness_batch >= cfg.fitness_threshold
|
| 2872 |
+
|
| 2873 |
+
for i, is_valid in enumerate(valid_mask.tolist()):
|
| 2874 |
+
if is_valid:
|
| 2875 |
+
valid_at_mag.append(batch_configs[i])
|
| 2876 |
+
|
| 2877 |
+
total_tested += n_configs
|
| 2878 |
+
mag_time = time.perf_counter() - mag_start
|
| 2879 |
+
|
| 2880 |
+
if valid_at_mag:
|
| 2881 |
+
if cfg.verbose:
|
| 2882 |
+
print(f"FOUND {len(valid_at_mag)} solutions! ({mag_time:.2f}s)")
|
| 2883 |
+
|
| 2884 |
+
optimal_mag = mag
|
| 2885 |
+
all_solutions = valid_at_mag
|
| 2886 |
+
|
| 2887 |
+
if cfg.verbose:
|
| 2888 |
+
print(f" [EXHAUSTIVE] Optimal magnitude: {optimal_mag}")
|
| 2889 |
+
print(f" [EXHAUSTIVE] Solutions found: {len(all_solutions)}")
|
| 2890 |
+
print(f" [EXHAUSTIVE] Solution analysis:")
|
| 2891 |
+
print(f" {'#':<3} {'Mag':<5} {'NZ':<4} {'Max|w|':<7} {'Sparse%':<8} {'Weights'}")
|
| 2892 |
+
print(f" {'-'*60}")
|
| 2893 |
+
for i, sol in enumerate(all_solutions[:20]):
|
| 2894 |
+
mag = sum(abs(v) for v in sol)
|
| 2895 |
+
nz = sum(1 for v in sol if v != 0)
|
| 2896 |
+
max_w = max(abs(v) for v in sol)
|
| 2897 |
+
sparsity = 100 * (len(sol) - nz) / len(sol)
|
| 2898 |
+
print(f" {i+1:<3} {mag:<5} {nz:<4} {max_w:<7} {sparsity:<8.1f} {sol}")
|
| 2899 |
+
if len(all_solutions) > 20:
|
| 2900 |
+
print(f" ... and {len(all_solutions) - 20} more")
|
| 2901 |
+
|
| 2902 |
+
break
|
| 2903 |
+
else:
|
| 2904 |
+
if cfg.verbose:
|
| 2905 |
+
print(f"none ({mag_time:.2f}s)")
|
| 2906 |
+
|
| 2907 |
+
elapsed = time.perf_counter() - start
|
| 2908 |
+
|
| 2909 |
+
if all_solutions:
|
| 2910 |
+
best_combo = all_solutions[0]
|
| 2911 |
+
best_weights = vector_to_weights(best_combo)
|
| 2912 |
+
best_fitness = evaluator.evaluate_single(best_weights)
|
| 2913 |
+
else:
|
| 2914 |
+
best_weights = circuit.clone_weights()
|
| 2915 |
+
best_fitness = evaluator.evaluate_single(best_weights)
|
| 2916 |
+
optimal_mag = max_mag
|
| 2917 |
|
| 2918 |
if cfg.verbose:
|
|
|
|
| 2919 |
print(f" [EXHAUSTIVE COMPLETE]")
|
| 2920 |
+
print(f" - Configurations tested: {total_tested:,}")
|
| 2921 |
+
print(f" - Optimal magnitude: {optimal_mag} (original: {max_mag})")
|
| 2922 |
+
print(f" - Total solutions at optimal: {len(all_solutions)}")
|
| 2923 |
+
print(f" - Reduction: {(1 - optimal_mag/max_mag)*100:.1f}%")
|
| 2924 |
+
print(f" - Time: {elapsed:.1f}s")
|
| 2925 |
|
| 2926 |
return PruneResult(
|
| 2927 |
method='exhaustive',
|
|
|
|
| 2929 |
final_stats=circuit.stats(best_weights),
|
| 2930 |
final_weights=best_weights,
|
| 2931 |
fitness=best_fitness,
|
| 2932 |
+
time_seconds=elapsed,
|
| 2933 |
metadata={
|
| 2934 |
+
'optimal_magnitude': optimal_mag,
|
| 2935 |
+
'total_tested': total_tested,
|
| 2936 |
+
'solutions_count': len(all_solutions),
|
| 2937 |
+
'all_solutions': all_solutions[:100]
|
| 2938 |
}
|
| 2939 |
)
|
| 2940 |
|