CharlesCNorton committed on
Commit
43d198d
·
1 Parent(s): ff1ca32

Add magnitude-ordered exhaustive search with GPU-vectorized evaluation

Browse files

- Exhaustive search now searches by magnitude level (0, 1, 2, ...) and stops
at first valid magnitude, finding ALL solutions at that level
- Add matrix-style weight detection for circuits using layer1.weight [H,I] format
- Add GPU-vectorized evaluation for matrix-style circuits (massive speedup)
- Add MUX/DEMUX native forward function detection
- Report full metrics for solutions: magnitude, nonzero count, max|w|, sparsity
- Fix CIRCUITS_PATH to point to threshold-logic-circuits

Files changed (1) hide show
  1. prune.py +226 -65
prune.py CHANGED
@@ -47,7 +47,7 @@ except ImportError:
47
 
48
  warnings.filterwarnings('ignore')
49
 
50
- CIRCUITS_PATH = Path('D:/threshold-circuits')
51
  RESULTS_PATH = CIRCUITS_PATH / 'pruned_results'
52
 
53
 
@@ -534,6 +534,17 @@ class AdaptiveCircuit:
534
  model = cls(weights)
535
  return list(model(int(inputs[0]), int(inputs[1])))
536
  return wrapper
 
 
 
 
 
 
 
 
 
 
 
537
 
538
  for attr_name in dir(module):
539
  if attr_name.startswith('_'):
@@ -914,6 +925,11 @@ class BatchedEvaluator:
914
  self.n_inputs = circuit.spec.inputs
915
  self.n_outputs = circuit.spec.outputs
916
 
 
 
 
 
 
917
  self.use_vmap = VMAP_AVAILABLE and cfg.device == 'cuda' and not circuit.has_native
918
  self.vmap_forward = None
919
  self.vmapped_fn = None
@@ -935,7 +951,87 @@ class BatchedEvaluator:
935
  batched_status = "yes" if self.batched_ready and not self.use_native_eval else "no"
936
  native_status = "forced" if self.use_native_eval else ("available" if circuit.has_native else "none")
937
  seq_deps = "yes" if self.has_sequential_deps else "no"
938
- print(f" [EVAL] Evaluator ready: batch={self.max_batch:,}, vmap={vmap_status}, batched={batched_status}, native={native_status}, seq_deps={seq_deps}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
939
 
940
  def _detect_sequential_dependencies(self) -> bool:
941
  """Detect if circuit has neurons that depend on other neurons' outputs."""
@@ -1086,6 +1182,13 @@ class BatchedEvaluator:
1086
  if pop_size > self.max_batch:
1087
  return self._evaluate_chunked(population)
1088
 
 
 
 
 
 
 
 
1089
  if self.use_native_eval:
1090
  return self._evaluate_native_parallel(population)
1091
 
@@ -2656,26 +2759,55 @@ def prune_fanin(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Conf
2656
  )
2657
 
2658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2659
  def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2660
  """
2661
- Exhaustive search over all integer weight combinations.
2662
 
2663
- Unlike evolutionary/annealing methods that perturb existing weights,
2664
- this searches ALL possible combinations within a range. Can find
2665
- globally optimal solutions but only feasible for small circuits.
2666
 
2667
- Complexity: O((2*range+1)^n_params) - exponential in parameter count.
2668
  """
2669
  start = time.perf_counter()
2670
  original = circuit.stats()
2671
 
2672
  n_params = original['total']
2673
- search_range = cfg.exhaustive_range
2674
 
2675
  if n_params > cfg.exhaustive_max_params:
2676
  if cfg.verbose:
2677
  print(f" [EXHAUSTIVE] Skipping: {n_params} params exceeds max {cfg.exhaustive_max_params}")
2678
- print(f" [EXHAUSTIVE] Search space would be {(2*search_range+1)**n_params:,} combinations")
2679
  return PruneResult(
2680
  method='exhaustive',
2681
  original_stats=original,
@@ -2686,28 +2818,15 @@ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
2686
  metadata={'skipped': True, 'reason': 'too_many_params'}
2687
  )
2688
 
2689
- search_space = (2 * search_range + 1) ** n_params
2690
-
2691
  if cfg.verbose:
2692
  print(f" [EXHAUSTIVE] Parameters: {n_params}")
2693
- print(f" [EXHAUSTIVE] Range: [{-search_range}, {search_range}]")
2694
- print(f" [EXHAUSTIVE] Search space: {search_space:,} combinations")
2695
 
2696
  weight_keys = list(circuit.weights.keys())
2697
  weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
2698
  weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}
2699
 
2700
- best_weights = circuit.clone_weights()
2701
- best_mag = original['magnitude']
2702
- best_fitness = evaluator.evaluate_single(best_weights)
2703
-
2704
- values = list(range(-search_range, search_range + 1))
2705
-
2706
- tested = 0
2707
- valid_found = 0
2708
- report_interval = max(1, search_space // 100)
2709
- last_report_time = start
2710
-
2711
  def vector_to_weights(vec):
2712
  weights = {}
2713
  idx = 0
@@ -2717,50 +2836,92 @@ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
2717
  idx += size
2718
  return weights
2719
 
2720
- if cfg.verbose:
2721
- print(f" [EXHAUSTIVE] Starting search...")
2722
- print(f" [EXHAUSTIVE] Progress updates every 1%")
2723
 
2724
- for combo in product(values, repeat=n_params):
2725
- tested += 1
 
 
2726
 
2727
- weights = vector_to_weights(combo)
2728
- fitness = evaluator.evaluate_single(weights)
2729
 
2730
- if fitness >= cfg.fitness_threshold:
2731
- valid_found += 1
2732
- mag = sum(abs(v) for v in combo)
2733
 
2734
- if mag < best_mag:
2735
- best_mag = mag
2736
- best_weights = {k: v.clone() for k, v in weights.items()}
2737
- best_fitness = fitness
2738
 
2739
- if cfg.verbose:
2740
- elapsed = time.perf_counter() - start
2741
- print(f" [{elapsed:6.1f}s] NEW BEST: magnitude={mag}, weights={combo}")
2742
-
2743
- if cfg.verbose and tested % report_interval == 0:
2744
- now = time.perf_counter()
2745
- elapsed = now - start
2746
- interval_time = now - last_report_time
2747
- rate = report_interval / interval_time if interval_time > 0 else 0
2748
- overall_rate = tested / elapsed if elapsed > 0 else 0
2749
- eta = (search_space - tested) / overall_rate if overall_rate > 0 else 0
2750
- pct = 100 * tested / search_space
2751
- print(f" [{elapsed:6.1f}s] {pct:5.1f}% | {tested:,}/{search_space:,} | "
2752
- f"valid: {valid_found:,} | best: {best_mag:.0f} | "
2753
- f"{rate:,.0f}/s (avg {overall_rate:,.0f}/s) | ETA: {eta:.0f}s")
2754
- last_report_time = now
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2755
 
2756
  if cfg.verbose:
2757
- elapsed = time.perf_counter() - start
2758
  print(f" [EXHAUSTIVE COMPLETE]")
2759
- print(f" - Combinations tested: {tested:,}")
2760
- print(f" - Valid solutions found: {valid_found:,}")
2761
- print(f" - Best magnitude: {best_mag:.0f} (original: {original['magnitude']:.0f})")
2762
- print(f" - Reduction: {(1 - best_mag/original['magnitude'])*100:.1f}%")
2763
- print(f" - Time: {elapsed:.1f}s ({tested/elapsed:.0f} combos/s)")
2764
 
2765
  return PruneResult(
2766
  method='exhaustive',
@@ -2768,12 +2929,12 @@ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
2768
  final_stats=circuit.stats(best_weights),
2769
  final_weights=best_weights,
2770
  fitness=best_fitness,
2771
- time_seconds=time.perf_counter() - start,
2772
  metadata={
2773
- 'search_space': search_space,
2774
- 'tested': tested,
2775
- 'valid_found': valid_found,
2776
- 'search_range': search_range
2777
  }
2778
  )
2779
 
 
47
 
48
  warnings.filterwarnings('ignore')
49
 
50
+ CIRCUITS_PATH = Path('D:/threshold-logic-circuits')
51
  RESULTS_PATH = CIRCUITS_PATH / 'pruned_results'
52
 
53
 
 
534
  model = cls(weights)
535
  return list(model(int(inputs[0]), int(inputs[1])))
536
  return wrapper
537
+ elif 'mux' in name and hasattr(module, 'mux'):
538
+ mux_fn = module.mux
539
+ def wrapper(inputs, weights):
540
+ return [mux_fn(int(inputs[0]), int(inputs[1]), int(inputs[2]), weights)]
541
+ return wrapper
542
+ elif 'demux' in name and hasattr(module, 'demux'):
543
+ demux_fn = module.demux
544
+ def wrapper(inputs, weights):
545
+ result = demux_fn(int(inputs[0]), int(inputs[1]), weights)
546
+ return list(result) if isinstance(result, (list, tuple)) else [result]
547
+ return wrapper
548
 
549
  for attr_name in dir(module):
550
  if attr_name.startswith('_'):
 
925
  self.n_inputs = circuit.spec.inputs
926
  self.n_outputs = circuit.spec.outputs
927
 
928
+ self.matrix_style = self._detect_matrix_style()
929
+ self.matrix_layout = None
930
+ if self.matrix_style:
931
+ self.matrix_layout = self._build_matrix_layout()
932
+
933
  self.use_vmap = VMAP_AVAILABLE and cfg.device == 'cuda' and not circuit.has_native
934
  self.vmap_forward = None
935
  self.vmapped_fn = None
 
951
  batched_status = "yes" if self.batched_ready and not self.use_native_eval else "no"
952
  native_status = "forced" if self.use_native_eval else ("available" if circuit.has_native else "none")
953
  seq_deps = "yes" if self.has_sequential_deps else "no"
954
+ matrix_status = "yes" if self.matrix_style else "no"
955
+ print(f" [EVAL] Evaluator ready: batch={self.max_batch:,}, vmap={vmap_status}, matrix={matrix_status}, batched={batched_status}, native={native_status}, seq_deps={seq_deps}")
956
+
957
+ def _detect_matrix_style(self) -> bool:
958
+ """Detect if circuit uses matrix-style weights (layer1.weight with multiple rows)."""
959
+ weights = self.circuit.weights
960
+ for key, tensor in weights.items():
961
+ if 'layer1.weight' in key or key == 'layer1.weight':
962
+ if tensor.dim() == 2 and tensor.shape[0] > 1:
963
+ return True
964
+ return False
965
+
966
+ def _build_matrix_layout(self) -> dict:
967
+ """Build layout info for matrix-style weight extraction.
968
+ MUST match the exact dict order used by circuit.weights_to_vector."""
969
+ weights = self.circuit.weights
970
+ weight_keys = list(weights.keys())
971
+ layout = {'layers': {}, 'total_params': 0, 'key_order': weight_keys}
972
+
973
+ idx = 0
974
+ for key in weight_keys:
975
+ tensor = weights[key]
976
+ size = tensor.numel()
977
+
978
+ if '.weight' in key or key.endswith('weight'):
979
+ layer_name = key.replace('.weight', '')
980
+ layout['layers'].setdefault(layer_name, {})
981
+ layout['layers'][layer_name]['w_start'] = idx
982
+ layout['layers'][layer_name]['w_end'] = idx + size
983
+ layout['layers'][layer_name]['w_shape'] = tuple(tensor.shape)
984
+ elif '.bias' in key or key.endswith('bias'):
985
+ layer_name = key.replace('.bias', '')
986
+ layout['layers'].setdefault(layer_name, {})
987
+ layout['layers'][layer_name]['b_start'] = idx
988
+ layout['layers'][layer_name]['b_end'] = idx + size
989
+ layout['layers'][layer_name]['b_shape'] = tuple(tensor.shape)
990
+
991
+ idx += size
992
+
993
+ layout['total_params'] = idx
994
+ layout['layer_order'] = sorted(layout['layers'].keys())
995
+ return layout
996
+
997
+ def _evaluate_matrix_vectorized(self, population: torch.Tensor) -> torch.Tensor:
998
+ """
999
+ Fully vectorized evaluation for matrix-style weights.
1000
+ Handles circuits with layer1.weight [H, I], layer1.bias [H], layer2.weight [O, H], layer2.bias [O].
1001
+ """
1002
+ pop_size = population.shape[0]
1003
+ device = population.device
1004
+
1005
+ layers = self.matrix_layout['layers']
1006
+ layer_order = self.matrix_layout['layer_order']
1007
+
1008
+ if len(layer_order) < 2:
1009
+ raise ValueError("Matrix-style eval requires at least 2 layers")
1010
+
1011
+ with torch.no_grad():
1012
+ l1_name = layer_order[0]
1013
+ l1 = layers[l1_name]
1014
+ l1_w = population[:, l1['w_start']:l1['w_end']].view(pop_size, *l1['w_shape'])
1015
+ l1_b = population[:, l1['b_start']:l1['b_end']].view(pop_size, *l1['b_shape'])
1016
+
1017
+ l2_name = layer_order[1]
1018
+ l2 = layers[l2_name]
1019
+ l2_w = population[:, l2['w_start']:l2['w_end']].view(pop_size, *l2['w_shape'])
1020
+ l2_b = population[:, l2['b_start']:l2['b_end']].view(pop_size, *l2['b_shape'])
1021
+
1022
+ inp = self.test_inputs
1023
+
1024
+ hidden = torch.einsum('ti,bhi->bth', inp, l1_w) + l1_b.unsqueeze(1)
1025
+ hidden = (hidden >= 0).float()
1026
+
1027
+ output = torch.einsum('bth,boh->bto', hidden, l2_w) + l2_b.unsqueeze(1)
1028
+ output = (output >= 0).float()
1029
+
1030
+ expected = self.test_expected.unsqueeze(0).expand(pop_size, -1, -1)
1031
+ correct = (output == expected).all(dim=-1).float().sum(dim=-1)
1032
+ fitness = correct / self.n_cases
1033
+
1034
+ return fitness
1035
 
1036
  def _detect_sequential_dependencies(self) -> bool:
1037
  """Detect if circuit has neurons that depend on other neurons' outputs."""
 
1182
  if pop_size > self.max_batch:
1183
  return self._evaluate_chunked(population)
1184
 
1185
+ if self.matrix_style and self.matrix_layout:
1186
+ try:
1187
+ return self._evaluate_matrix_vectorized(population)
1188
+ except Exception as e:
1189
+ if self.cfg.verbose:
1190
+ print(f" Matrix vectorized eval failed ({e}), trying other methods")
1191
+
1192
  if self.use_native_eval:
1193
  return self._evaluate_native_parallel(population)
1194
 
 
2759
  )
2760
 
2761
 
2762
+ def _partitions(total: int, n: int, max_val: int):
2763
+ """Generate all ways to partition 'total' into 'n' non-negative integers <= max_val."""
2764
+ if n == 0:
2765
+ if total == 0:
2766
+ yield []
2767
+ return
2768
+ for i in range(min(total, max_val) + 1):
2769
+ for rest in _partitions(total - i, n - 1, max_val):
2770
+ yield [i] + rest
2771
+
2772
+
2773
+ def _all_signs(abs_vals: list):
2774
+ """Generate all sign combinations for absolute values."""
2775
+ if not abs_vals:
2776
+ yield []
2777
+ return
2778
+ for rest in _all_signs(abs_vals[1:]):
2779
+ if abs_vals[0] == 0:
2780
+ yield [0] + rest
2781
+ else:
2782
+ yield [abs_vals[0]] + rest
2783
+ yield [-abs_vals[0]] + rest
2784
+
2785
+
2786
+ def _configs_at_magnitude(mag: int, n_params: int):
2787
+ """Generate all n_params-length configs with given total magnitude."""
2788
+ for partition in _partitions(mag, n_params, mag):
2789
+ for signed in _all_signs(partition):
2790
+ yield tuple(signed)
2791
+
2792
+
2793
  def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2794
  """
2795
+ Exhaustive search by magnitude level - finds provably optimal solutions.
2796
 
2797
+ Searches magnitude 0, then 1, then 2, ... until valid solutions found.
2798
+ Returns ALL valid solutions at the minimum magnitude (to discover families).
 
2799
 
2800
+ Much faster than arbitrary-order search since it stops at first valid magnitude.
2801
  """
2802
  start = time.perf_counter()
2803
  original = circuit.stats()
2804
 
2805
  n_params = original['total']
2806
+ max_mag = int(original['magnitude'])
2807
 
2808
  if n_params > cfg.exhaustive_max_params:
2809
  if cfg.verbose:
2810
  print(f" [EXHAUSTIVE] Skipping: {n_params} params exceeds max {cfg.exhaustive_max_params}")
 
2811
  return PruneResult(
2812
  method='exhaustive',
2813
  original_stats=original,
 
2818
  metadata={'skipped': True, 'reason': 'too_many_params'}
2819
  )
2820
 
 
 
2821
  if cfg.verbose:
2822
  print(f" [EXHAUSTIVE] Parameters: {n_params}")
2823
+ print(f" [EXHAUSTIVE] Original magnitude: {max_mag}")
2824
+ print(f" [EXHAUSTIVE] Searching by magnitude level (0, 1, 2, ...)")
2825
 
2826
  weight_keys = list(circuit.weights.keys())
2827
  weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
2828
  weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}
2829
 
 
 
 
 
 
 
 
 
 
 
 
2830
  def vector_to_weights(vec):
2831
  weights = {}
2832
  idx = 0
 
2836
  idx += size
2837
  return weights
2838
 
2839
+ total_tested = 0
2840
+ all_solutions = []
2841
+ optimal_mag = None
2842
 
2843
+ for mag in range(0, max_mag + 1):
2844
+ mag_start = time.perf_counter()
2845
+ configs = list(_configs_at_magnitude(mag, n_params))
2846
+ n_configs = len(configs)
2847
 
2848
+ if n_configs == 0:
2849
+ continue
2850
 
2851
+ if cfg.verbose:
2852
+ print(f" Magnitude {mag}: {n_configs:,} configurations...", end=" ", flush=True)
 
2853
 
2854
+ valid_at_mag = []
 
 
 
2855
 
2856
+ batch_size = min(100000, n_configs)
2857
+ for batch_start in range(0, n_configs, batch_size):
2858
+ batch_end = min(batch_start + batch_size, n_configs)
2859
+ batch_configs = configs[batch_start:batch_end]
2860
+
2861
+ population = torch.tensor(batch_configs, dtype=torch.float32, device=cfg.device)
2862
+
2863
+ try:
2864
+ fitness_batch = evaluator.evaluate_population(population)
2865
+ except:
2866
+ fitness_batch = torch.tensor([
2867
+ evaluator.evaluate_single(vector_to_weights(c))
2868
+ for c in batch_configs
2869
+ ], device=cfg.device)
2870
+
2871
+ valid_mask = fitness_batch >= cfg.fitness_threshold
2872
+
2873
+ for i, is_valid in enumerate(valid_mask.tolist()):
2874
+ if is_valid:
2875
+ valid_at_mag.append(batch_configs[i])
2876
+
2877
+ total_tested += n_configs
2878
+ mag_time = time.perf_counter() - mag_start
2879
+
2880
+ if valid_at_mag:
2881
+ if cfg.verbose:
2882
+ print(f"FOUND {len(valid_at_mag)} solutions! ({mag_time:.2f}s)")
2883
+
2884
+ optimal_mag = mag
2885
+ all_solutions = valid_at_mag
2886
+
2887
+ if cfg.verbose:
2888
+ print(f" [EXHAUSTIVE] Optimal magnitude: {optimal_mag}")
2889
+ print(f" [EXHAUSTIVE] Solutions found: {len(all_solutions)}")
2890
+ print(f" [EXHAUSTIVE] Solution analysis:")
2891
+ print(f" {'#':<3} {'Mag':<5} {'NZ':<4} {'Max|w|':<7} {'Sparse%':<8} {'Weights'}")
2892
+ print(f" {'-'*60}")
2893
+ for i, sol in enumerate(all_solutions[:20]):
2894
+ mag = sum(abs(v) for v in sol)
2895
+ nz = sum(1 for v in sol if v != 0)
2896
+ max_w = max(abs(v) for v in sol)
2897
+ sparsity = 100 * (len(sol) - nz) / len(sol)
2898
+ print(f" {i+1:<3} {mag:<5} {nz:<4} {max_w:<7} {sparsity:<8.1f} {sol}")
2899
+ if len(all_solutions) > 20:
2900
+ print(f" ... and {len(all_solutions) - 20} more")
2901
+
2902
+ break
2903
+ else:
2904
+ if cfg.verbose:
2905
+ print(f"none ({mag_time:.2f}s)")
2906
+
2907
+ elapsed = time.perf_counter() - start
2908
+
2909
+ if all_solutions:
2910
+ best_combo = all_solutions[0]
2911
+ best_weights = vector_to_weights(best_combo)
2912
+ best_fitness = evaluator.evaluate_single(best_weights)
2913
+ else:
2914
+ best_weights = circuit.clone_weights()
2915
+ best_fitness = evaluator.evaluate_single(best_weights)
2916
+ optimal_mag = max_mag
2917
 
2918
  if cfg.verbose:
 
2919
  print(f" [EXHAUSTIVE COMPLETE]")
2920
+ print(f" - Configurations tested: {total_tested:,}")
2921
+ print(f" - Optimal magnitude: {optimal_mag} (original: {max_mag})")
2922
+ print(f" - Total solutions at optimal: {len(all_solutions)}")
2923
+ print(f" - Reduction: {(1 - optimal_mag/max_mag)*100:.1f}%")
2924
+ print(f" - Time: {elapsed:.1f}s")
2925
 
2926
  return PruneResult(
2927
  method='exhaustive',
 
2929
  final_stats=circuit.stats(best_weights),
2930
  final_weights=best_weights,
2931
  fitness=best_fitness,
2932
+ time_seconds=elapsed,
2933
  metadata={
2934
+ 'optimal_magnitude': optimal_mag,
2935
+ 'total_tested': total_tested,
2936
+ 'solutions_count': len(all_solutions),
2937
+ 'all_solutions': all_solutions[:100]
2938
  }
2939
  )
2940