CharlesCNorton committed on
Commit
ff1ca32
·
1 Parent(s): 0780584

Add exhaustive search methods for provable optimality

Browse files

- Add prune_exhaustive: searches ALL weight combinations in range
to find globally optimal magnitude (not just local improvements)

- Rewrite prune_neuron to be truly exhaustive: for each neuron,
searches all weight configs for remaining neurons to prove
whether neuron is necessary (not just greedy zeroing)

- Add config options: exhaustive_range, exhaustive_max_params

- Add method aliases: 'exh', 'brute' for exhaustive search

These methods enable provable optimality for small circuits where
evolutionary/annealing methods get stuck in local minima.

Files changed (1) hide show
  1. prune.py +230 -21
prune.py CHANGED
@@ -138,9 +138,13 @@ class Config:
138
  run_gate_subst: bool = True
139
  run_symmetry: bool = True
140
  run_fanin: bool = True
 
141
 
142
  magnitude_passes: int = 100
143
 
 
 
 
144
  evo_generations: int = 2000
145
  evo_pop_size: int = 0
146
  evo_elite_ratio: float = 0.05
@@ -1730,44 +1734,125 @@ def prune_annealing(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
1730
 
1731
 
1732
  def prune_neuron(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
1733
- """Neuron-level pruning."""
 
 
 
 
 
 
1734
  start = time.perf_counter()
1735
- weights = circuit.clone_weights()
1736
- original = circuit.stats(weights)
1737
 
1738
  neuron_groups = defaultdict(list)
1739
- for key in weights.keys():
1740
  parts = key.rsplit('.', 1)
1741
  neuron_name = parts[0] if len(parts) == 2 else key.split('.')[0]
1742
  neuron_groups[neuron_name].append(key)
1743
 
 
 
 
 
 
 
1744
  if cfg.verbose:
1745
  print(f" Found {len(neuron_groups)} neuron groups")
 
1746
 
1747
- removed = 0
1748
- for neuron_name, keys in neuron_groups.items():
1749
- saved = {k: weights[k].clone() for k in keys if k in weights}
1750
 
1751
- for k in keys:
1752
- if k in weights:
1753
- weights[k] = torch.zeros_like(weights[k])
 
 
 
 
 
1754
 
1755
- if evaluator.evaluate_single(weights) >= cfg.fitness_threshold:
1756
- removed += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1757
  if cfg.verbose:
1758
- print(f" Removed: {neuron_name}")
 
 
 
1759
  else:
1760
- for k, v in saved.items():
1761
- weights[k] = v
 
 
 
 
 
 
1762
 
1763
  return PruneResult(
1764
  method='neuron',
1765
  original_stats=original,
1766
- final_stats=circuit.stats(weights),
1767
- final_weights=weights,
1768
- fitness=evaluator.evaluate_single(weights),
1769
  time_seconds=time.perf_counter() - start,
1770
- metadata={'neurons_removed': removed}
1771
  )
1772
 
1773
 
@@ -2571,6 +2656,128 @@ def prune_fanin(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Conf
2571
  )
2572
 
2573
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2574
  def export_coq(circuit: AdaptiveCircuit, weights: Dict[str, torch.Tensor], cfg: Config) -> Path:
2575
  """
2576
  Export threshold circuit to Coq for formal verification.
@@ -2682,6 +2889,7 @@ def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneRes
2682
  ('gate_subst', cfg.run_gate_subst, lambda: prune_gate_substitution(circuit, evaluator, cfg)),
2683
  ('symmetry', cfg.run_symmetry, lambda: prune_symmetry(circuit, evaluator, cfg)),
2684
  ('fanin', cfg.run_fanin, lambda: prune_fanin(circuit, evaluator, cfg)),
 
2685
  ('random', cfg.run_random, lambda: prune_random(circuit, evaluator, cfg)),
2686
  ('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
2687
  ('annealing', cfg.run_annealing, lambda: prune_annealing(circuit, evaluator, cfg)),
@@ -2890,7 +3098,7 @@ def main():
2890
  all_methods = ['magnitude', 'zero', 'quantize', 'evolutionary', 'annealing',
2891
  'neuron', 'lottery', 'topology', 'structured', 'sensitivity',
2892
  'weight_sharing', 'random', 'pareto', 'depth', 'gate_subst',
2893
- 'symmetry', 'fanin']
2894
  for m in all_methods:
2895
  setattr(cfg, f'run_{m}', False)
2896
 
@@ -2913,7 +3121,8 @@ def main():
2913
  'depth': 'depth',
2914
  'gate': 'gate_subst', 'gate_subst': 'gate_subst', 'subst': 'gate_subst',
2915
  'sym': 'symmetry', 'symmetry': 'symmetry',
2916
- 'fanin': 'fanin', 'fan': 'fanin'
 
2917
  }
2918
  if m in method_map:
2919
  setattr(cfg, f'run_{method_map[m]}', True)
 
138
  run_gate_subst: bool = True
139
  run_symmetry: bool = True
140
  run_fanin: bool = True
141
+ run_exhaustive: bool = True
142
 
143
  magnitude_passes: int = 100
144
 
145
+ exhaustive_range: int = 2
146
+ exhaustive_max_params: int = 12
147
+
148
  evo_generations: int = 2000
149
  evo_pop_size: int = 0
150
  evo_elite_ratio: float = 0.05
 
1734
 
1735
 
1736
  def prune_neuron(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
1737
+ """
1738
+ Truly exhaustive neuron-level pruning.
1739
+
1740
+ For each neuron, fix its parameters to zero and exhaustively search
1741
+ all weight configurations for the REMAINING neurons. If any config
1742
+ works, the neuron is provably unnecessary.
1743
+ """
1744
  start = time.perf_counter()
1745
+ original = circuit.stats()
 
1746
 
1747
  neuron_groups = defaultdict(list)
1748
+ for key in circuit.weights.keys():
1749
  parts = key.rsplit('.', 1)
1750
  neuron_name = parts[0] if len(parts) == 2 else key.split('.')[0]
1751
  neuron_groups[neuron_name].append(key)
1752
 
1753
+ neuron_names = list(neuron_groups.keys())
1754
+ weight_keys = list(circuit.weights.keys())
1755
+ weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
1756
+ weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}
1757
+ n_total_params = sum(weight_sizes.values())
1758
+
1759
  if cfg.verbose:
1760
  print(f" Found {len(neuron_groups)} neuron groups")
1761
+ print(f" Total parameters: {n_total_params}")
1762
 
1763
+ search_range = cfg.exhaustive_range
1764
+ values = list(range(-search_range, search_range + 1))
 
1765
 
1766
+ def vector_to_weights(vec):
1767
+ weights = {}
1768
+ idx = 0
1769
+ for k in weight_keys:
1770
+ size = weight_sizes[k]
1771
+ weights[k] = torch.tensor(vec[idx:idx+size], dtype=torch.float32, device=cfg.device).view(weight_shapes[k])
1772
+ idx += size
1773
+ return weights
1774
 
1775
+ def get_neuron_param_indices(neuron_name):
1776
+ indices = []
1777
+ idx = 0
1778
+ for k in weight_keys:
1779
+ size = weight_sizes[k]
1780
+ if k in neuron_groups[neuron_name]:
1781
+ indices.extend(range(idx, idx + size))
1782
+ idx += size
1783
+ return indices
1784
+
1785
+ best_weights = circuit.clone_weights()
1786
+ best_neurons_removed = 0
1787
+ removed_neuron_names = []
1788
+
1789
+ for neuron_name in neuron_names:
1790
+ neuron_indices = set(get_neuron_param_indices(neuron_name))
1791
+ other_indices = [i for i in range(n_total_params) if i not in neuron_indices]
1792
+ n_other_params = len(other_indices)
1793
+
1794
+ if n_other_params > cfg.exhaustive_max_params:
1795
+ if cfg.verbose:
1796
+ print(f" [{neuron_name}] Skipping: {n_other_params} remaining params > max {cfg.exhaustive_max_params}")
1797
+ continue
1798
+
1799
+ search_space = (2 * search_range + 1) ** n_other_params
1800
+
1801
+ if cfg.verbose:
1802
+ print(f" [{neuron_name}] Testing removal: searching {search_space:,} configs for {n_other_params} remaining params")
1803
+
1804
+ found_valid = False
1805
+ best_config = None
1806
+ best_mag = float('inf')
1807
+ tested = 0
1808
+ report_interval = max(1, search_space // 10)
1809
+
1810
+ for combo in product(values, repeat=n_other_params):
1811
+ tested += 1
1812
+
1813
+ full_vec = [0] * n_total_params
1814
+ for i, val in zip(other_indices, combo):
1815
+ full_vec[i] = val
1816
+
1817
+ weights = vector_to_weights(full_vec)
1818
+ fitness = evaluator.evaluate_single(weights)
1819
+
1820
+ if fitness >= cfg.fitness_threshold:
1821
+ mag = sum(abs(v) for v in combo)
1822
+ if not found_valid or mag < best_mag:
1823
+ found_valid = True
1824
+ best_mag = mag
1825
+ best_config = full_vec
1826
+
1827
+ if cfg.verbose and tested % report_interval == 0:
1828
+ elapsed = time.perf_counter() - start
1829
+ pct = 100 * tested / search_space
1830
+ print(f" [{elapsed:6.1f}s] {pct:5.1f}% | valid={'YES' if found_valid else 'no '} | best_mag={best_mag if found_valid else '-'}")
1831
+
1832
+ if found_valid:
1833
  if cfg.verbose:
1834
+ print(f" [REMOVABLE] {neuron_name} can be removed! Best remaining mag={best_mag}")
1835
+ removed_neuron_names.append(neuron_name)
1836
+ best_neurons_removed += 1
1837
+ best_weights = vector_to_weights(best_config)
1838
  else:
1839
+ if cfg.verbose:
1840
+ print(f" [REQUIRED] {neuron_name} is necessary")
1841
+
1842
+ if cfg.verbose:
1843
+ elapsed = time.perf_counter() - start
1844
+ print(f" [NEURON COMPLETE] {best_neurons_removed} neurons removable out of {len(neuron_names)}")
1845
+ print(f" Removable: {removed_neuron_names if removed_neuron_names else 'none'}")
1846
+ print(f" Time: {elapsed:.1f}s")
1847
 
1848
  return PruneResult(
1849
  method='neuron',
1850
  original_stats=original,
1851
+ final_stats=circuit.stats(best_weights),
1852
+ final_weights=best_weights,
1853
+ fitness=evaluator.evaluate_single(best_weights),
1854
  time_seconds=time.perf_counter() - start,
1855
+ metadata={'neurons_removed': best_neurons_removed, 'removed_names': removed_neuron_names}
1856
  )
1857
 
1858
 
 
2656
  )
2657
 
2658
 
2659
+ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2660
+ """
2661
+ Exhaustive search over all integer weight combinations.
2662
+
2663
+ Unlike evolutionary/annealing methods that perturb existing weights,
2664
+ this searches ALL possible combinations within a range. Can find
2665
+ globally optimal solutions but only feasible for small circuits.
2666
+
2667
+ Complexity: O((2*range+1)^n_params) - exponential in parameter count.
2668
+ """
2669
+ start = time.perf_counter()
2670
+ original = circuit.stats()
2671
+
2672
+ n_params = original['total']
2673
+ search_range = cfg.exhaustive_range
2674
+
2675
+ if n_params > cfg.exhaustive_max_params:
2676
+ if cfg.verbose:
2677
+ print(f" [EXHAUSTIVE] Skipping: {n_params} params exceeds max {cfg.exhaustive_max_params}")
2678
+ print(f" [EXHAUSTIVE] Search space would be {(2*search_range+1)**n_params:,} combinations")
2679
+ return PruneResult(
2680
+ method='exhaustive',
2681
+ original_stats=original,
2682
+ final_stats=original,
2683
+ final_weights=circuit.clone_weights(),
2684
+ fitness=evaluator.evaluate_single(circuit.weights),
2685
+ time_seconds=time.perf_counter() - start,
2686
+ metadata={'skipped': True, 'reason': 'too_many_params'}
2687
+ )
2688
+
2689
+ search_space = (2 * search_range + 1) ** n_params
2690
+
2691
+ if cfg.verbose:
2692
+ print(f" [EXHAUSTIVE] Parameters: {n_params}")
2693
+ print(f" [EXHAUSTIVE] Range: [{-search_range}, {search_range}]")
2694
+ print(f" [EXHAUSTIVE] Search space: {search_space:,} combinations")
2695
+
2696
+ weight_keys = list(circuit.weights.keys())
2697
+ weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
2698
+ weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}
2699
+
2700
+ best_weights = circuit.clone_weights()
2701
+ best_mag = original['magnitude']
2702
+ best_fitness = evaluator.evaluate_single(best_weights)
2703
+
2704
+ values = list(range(-search_range, search_range + 1))
2705
+
2706
+ tested = 0
2707
+ valid_found = 0
2708
+ report_interval = max(1, search_space // 100)
2709
+ last_report_time = start
2710
+
2711
+ def vector_to_weights(vec):
2712
+ weights = {}
2713
+ idx = 0
2714
+ for k in weight_keys:
2715
+ size = weight_sizes[k]
2716
+ weights[k] = torch.tensor(vec[idx:idx+size], dtype=torch.float32, device=cfg.device).view(weight_shapes[k])
2717
+ idx += size
2718
+ return weights
2719
+
2720
+ if cfg.verbose:
2721
+ print(f" [EXHAUSTIVE] Starting search...")
2722
+ print(f" [EXHAUSTIVE] Progress updates every 1%")
2723
+
2724
+ for combo in product(values, repeat=n_params):
2725
+ tested += 1
2726
+
2727
+ weights = vector_to_weights(combo)
2728
+ fitness = evaluator.evaluate_single(weights)
2729
+
2730
+ if fitness >= cfg.fitness_threshold:
2731
+ valid_found += 1
2732
+ mag = sum(abs(v) for v in combo)
2733
+
2734
+ if mag < best_mag:
2735
+ best_mag = mag
2736
+ best_weights = {k: v.clone() for k, v in weights.items()}
2737
+ best_fitness = fitness
2738
+
2739
+ if cfg.verbose:
2740
+ elapsed = time.perf_counter() - start
2741
+ print(f" [{elapsed:6.1f}s] NEW BEST: magnitude={mag}, weights={combo}")
2742
+
2743
+ if cfg.verbose and tested % report_interval == 0:
2744
+ now = time.perf_counter()
2745
+ elapsed = now - start
2746
+ interval_time = now - last_report_time
2747
+ rate = report_interval / interval_time if interval_time > 0 else 0
2748
+ overall_rate = tested / elapsed if elapsed > 0 else 0
2749
+ eta = (search_space - tested) / overall_rate if overall_rate > 0 else 0
2750
+ pct = 100 * tested / search_space
2751
+ print(f" [{elapsed:6.1f}s] {pct:5.1f}% | {tested:,}/{search_space:,} | "
2752
+ f"valid: {valid_found:,} | best: {best_mag:.0f} | "
2753
+ f"{rate:,.0f}/s (avg {overall_rate:,.0f}/s) | ETA: {eta:.0f}s")
2754
+ last_report_time = now
2755
+
2756
+ if cfg.verbose:
2757
+ elapsed = time.perf_counter() - start
2758
+ print(f" [EXHAUSTIVE COMPLETE]")
2759
+ print(f" - Combinations tested: {tested:,}")
2760
+ print(f" - Valid solutions found: {valid_found:,}")
2761
+ print(f" - Best magnitude: {best_mag:.0f} (original: {original['magnitude']:.0f})")
2762
+ print(f" - Reduction: {(1 - best_mag/original['magnitude'])*100:.1f}%")
2763
+ print(f" - Time: {elapsed:.1f}s ({tested/elapsed:.0f} combos/s)")
2764
+
2765
+ return PruneResult(
2766
+ method='exhaustive',
2767
+ original_stats=original,
2768
+ final_stats=circuit.stats(best_weights),
2769
+ final_weights=best_weights,
2770
+ fitness=best_fitness,
2771
+ time_seconds=time.perf_counter() - start,
2772
+ metadata={
2773
+ 'search_space': search_space,
2774
+ 'tested': tested,
2775
+ 'valid_found': valid_found,
2776
+ 'search_range': search_range
2777
+ }
2778
+ )
2779
+
2780
+
2781
  def export_coq(circuit: AdaptiveCircuit, weights: Dict[str, torch.Tensor], cfg: Config) -> Path:
2782
  """
2783
  Export threshold circuit to Coq for formal verification.
 
2889
  ('gate_subst', cfg.run_gate_subst, lambda: prune_gate_substitution(circuit, evaluator, cfg)),
2890
  ('symmetry', cfg.run_symmetry, lambda: prune_symmetry(circuit, evaluator, cfg)),
2891
  ('fanin', cfg.run_fanin, lambda: prune_fanin(circuit, evaluator, cfg)),
2892
+ ('exhaustive', cfg.run_exhaustive, lambda: prune_exhaustive(circuit, evaluator, cfg)),
2893
  ('random', cfg.run_random, lambda: prune_random(circuit, evaluator, cfg)),
2894
  ('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
2895
  ('annealing', cfg.run_annealing, lambda: prune_annealing(circuit, evaluator, cfg)),
 
3098
  all_methods = ['magnitude', 'zero', 'quantize', 'evolutionary', 'annealing',
3099
  'neuron', 'lottery', 'topology', 'structured', 'sensitivity',
3100
  'weight_sharing', 'random', 'pareto', 'depth', 'gate_subst',
3101
+ 'symmetry', 'fanin', 'exhaustive']
3102
  for m in all_methods:
3103
  setattr(cfg, f'run_{m}', False)
3104
 
 
3121
  'depth': 'depth',
3122
  'gate': 'gate_subst', 'gate_subst': 'gate_subst', 'subst': 'gate_subst',
3123
  'sym': 'symmetry', 'symmetry': 'symmetry',
3124
+ 'fanin': 'fanin', 'fan': 'fanin',
3125
+ 'exhaustive': 'exhaustive', 'exh': 'exhaustive', 'brute': 'exhaustive'
3126
  }
3127
  if m in method_map:
3128
  setattr(cfg, f'run_{method_map[m]}', True)