CharlesCNorton committed on
Commit
1721000
·
1 Parent(s): 43d198d

Remove Coq export functionality from pruner

Browse files

- Remove export_coq config option and coq_output_dir
- Remove export_coq() function
- Remove --export-coq CLI argument
- Update method count in docstring (17 -> 15)

Files changed (1) hide show
  1. prune.py +298 -403
prune.py CHANGED
@@ -2,14 +2,13 @@
2
  Threshold Circuit Pruner
3
 
4
  Comprehensive pruning framework for threshold logic circuits.
5
- Supports 17 pruning methods with GPU-optimized parallel evaluation.
6
 
7
  Usage:
8
  python prune.py threshold-hamming74decoder
9
  python prune.py threshold-hamming74decoder --methods evo,depth,symmetry
10
  python prune.py --list
11
  python prune.py --all --max-inputs 8
12
- python prune.py threshold-xor --export-coq
13
  """
14
 
15
  import torch
@@ -126,24 +125,22 @@ class Config:
126
  run_quantize: bool = True
127
  run_evolutionary: bool = True
128
  run_annealing: bool = True
129
- run_neuron: bool = True
130
- run_lottery: bool = True
131
  run_topology: bool = True
132
- run_structured: bool = True
133
  run_sensitivity: bool = True
134
  run_weight_sharing: bool = True
135
- run_random: bool = True
136
- run_pareto: bool = True
137
  run_depth: bool = True
138
  run_gate_subst: bool = True
139
  run_symmetry: bool = True
140
  run_fanin: bool = True
141
- run_exhaustive: bool = True
 
142
 
143
  magnitude_passes: int = 100
144
 
145
  exhaustive_range: int = 2
146
  exhaustive_max_params: int = 12
 
147
 
148
  evo_generations: int = 2000
149
  evo_pop_size: int = 0
@@ -161,10 +158,6 @@ class Config:
161
  annealing_parallel_chains: int = 0
162
 
163
  quantize_targets: List[float] = field(default_factory=lambda: [-1.0, 0.0, 1.0])
164
- pareto_levels: List[float] = field(default_factory=lambda: [1.0, 0.99, 0.95, 0.90, 0.80])
165
-
166
- lottery_rounds: int = 10
167
- lottery_prune_rate: float = 0.2
168
 
169
  topology_generations: int = 500
170
  topology_remove_prob: float = 0.2
@@ -172,14 +165,9 @@ class Config:
172
 
173
  sensitivity_samples: int = 1000
174
 
175
- random_iterations: int = 10000
176
-
177
  depth_max_collapse: int = 3
178
  fanin_target: int = 4
179
 
180
- export_coq: bool = False
181
- coq_output_dir: Path = field(default_factory=lambda: Path('D:/threshold-pruner/coq_exports'))
182
-
183
 
184
  @dataclass
185
  class CircuitSpec:
@@ -260,6 +248,7 @@ class ComputationGraph:
260
  }
261
  self.layer_groups[depth].append(neuron_name)
262
 
 
263
  self._identify_outputs()
264
 
265
  def _estimate_depth(self, name: str) -> int:
@@ -285,6 +274,36 @@ class ComputationGraph:
285
 
286
  return depth
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  def _identify_outputs(self):
289
  """Identify which neurons are outputs based on n_outputs and topology."""
290
 
@@ -398,11 +417,17 @@ class ComputationGraph:
398
  return outputs
399
 
400
  def _get_neuron_input(self, neuron_name: str, activations: Dict, raw_input: torch.Tensor, expected_size: int) -> torch.Tensor:
401
- """Determine input for a neuron based on naming conventions."""
402
- if expected_size == self.n_inputs or expected_size == raw_input.shape[-1]:
 
 
403
  return raw_input
404
 
405
- parts = neuron_name.split('.')
 
 
 
 
406
 
407
  if 'layer2' in neuron_name:
408
  base = neuron_name.replace('.layer2', '')
@@ -1058,6 +1083,36 @@ class BatchedEvaluator:
1058
  if original_fitness < 0.999 and self.cfg.verbose:
1059
  print(f" [EVAL ERROR] Native eval fitness={original_fitness:.4f}")
1060
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1061
  def _setup_vmap(self):
1062
  """Setup vmap-based parallel evaluation."""
1063
  try:
@@ -1123,24 +1178,25 @@ class BatchedEvaluator:
1123
 
1124
  input_size = w_shape[-1] if w_shape else 0
1125
 
1126
- input_source = 'raw'
1127
- input_neurons = []
1128
-
1129
- if input_size == self.n_inputs:
1130
- input_source = 'raw'
1131
- elif 'layer2' in neuron_name:
1132
- base = neuron_name.replace('.layer2', '')
1133
- or_key = f'{base}.layer1.or'
1134
- nand_key = f'{base}.layer1.nand'
1135
- if or_key in graph.neurons and nand_key in graph.neurons:
1136
- input_source = 'neurons'
1137
- input_neurons = [or_key, nand_key]
1138
- elif 'xor_final' in neuron_name:
1139
- prefix = neuron_name.split('.xor_final')[0]
1140
- candidates = [n for n in graph.neurons if n.startswith(prefix) and 'xor_' in n and 'final' not in n and 'layer2' in n]
1141
- if len(candidates) >= 2:
1142
- input_source = 'neurons'
1143
- input_neurons = sorted(candidates)[-2:]
 
1144
 
1145
  self.neuron_eval_order.append(neuron_name)
1146
  self.neuron_weight_slices[neuron_name] = {
@@ -1836,194 +1892,6 @@ def prune_annealing(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
1836
  )
1837
 
1838
 
1839
- def prune_neuron(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
1840
- """
1841
- Truly exhaustive neuron-level pruning.
1842
-
1843
- For each neuron, fix its parameters to zero and exhaustively search
1844
- all weight configurations for the REMAINING neurons. If any config
1845
- works, the neuron is provably unnecessary.
1846
- """
1847
- start = time.perf_counter()
1848
- original = circuit.stats()
1849
-
1850
- neuron_groups = defaultdict(list)
1851
- for key in circuit.weights.keys():
1852
- parts = key.rsplit('.', 1)
1853
- neuron_name = parts[0] if len(parts) == 2 else key.split('.')[0]
1854
- neuron_groups[neuron_name].append(key)
1855
-
1856
- neuron_names = list(neuron_groups.keys())
1857
- weight_keys = list(circuit.weights.keys())
1858
- weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
1859
- weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}
1860
- n_total_params = sum(weight_sizes.values())
1861
-
1862
- if cfg.verbose:
1863
- print(f" Found {len(neuron_groups)} neuron groups")
1864
- print(f" Total parameters: {n_total_params}")
1865
-
1866
- search_range = cfg.exhaustive_range
1867
- values = list(range(-search_range, search_range + 1))
1868
-
1869
- def vector_to_weights(vec):
1870
- weights = {}
1871
- idx = 0
1872
- for k in weight_keys:
1873
- size = weight_sizes[k]
1874
- weights[k] = torch.tensor(vec[idx:idx+size], dtype=torch.float32, device=cfg.device).view(weight_shapes[k])
1875
- idx += size
1876
- return weights
1877
-
1878
- def get_neuron_param_indices(neuron_name):
1879
- indices = []
1880
- idx = 0
1881
- for k in weight_keys:
1882
- size = weight_sizes[k]
1883
- if k in neuron_groups[neuron_name]:
1884
- indices.extend(range(idx, idx + size))
1885
- idx += size
1886
- return indices
1887
-
1888
- best_weights = circuit.clone_weights()
1889
- best_neurons_removed = 0
1890
- removed_neuron_names = []
1891
-
1892
- for neuron_name in neuron_names:
1893
- neuron_indices = set(get_neuron_param_indices(neuron_name))
1894
- other_indices = [i for i in range(n_total_params) if i not in neuron_indices]
1895
- n_other_params = len(other_indices)
1896
-
1897
- if n_other_params > cfg.exhaustive_max_params:
1898
- if cfg.verbose:
1899
- print(f" [{neuron_name}] Skipping: {n_other_params} remaining params > max {cfg.exhaustive_max_params}")
1900
- continue
1901
-
1902
- search_space = (2 * search_range + 1) ** n_other_params
1903
-
1904
- if cfg.verbose:
1905
- print(f" [{neuron_name}] Testing removal: searching {search_space:,} configs for {n_other_params} remaining params")
1906
-
1907
- found_valid = False
1908
- best_config = None
1909
- best_mag = float('inf')
1910
- tested = 0
1911
- report_interval = max(1, search_space // 10)
1912
-
1913
- for combo in product(values, repeat=n_other_params):
1914
- tested += 1
1915
-
1916
- full_vec = [0] * n_total_params
1917
- for i, val in zip(other_indices, combo):
1918
- full_vec[i] = val
1919
-
1920
- weights = vector_to_weights(full_vec)
1921
- fitness = evaluator.evaluate_single(weights)
1922
-
1923
- if fitness >= cfg.fitness_threshold:
1924
- mag = sum(abs(v) for v in combo)
1925
- if not found_valid or mag < best_mag:
1926
- found_valid = True
1927
- best_mag = mag
1928
- best_config = full_vec
1929
-
1930
- if cfg.verbose and tested % report_interval == 0:
1931
- elapsed = time.perf_counter() - start
1932
- pct = 100 * tested / search_space
1933
- print(f" [{elapsed:6.1f}s] {pct:5.1f}% | valid={'YES' if found_valid else 'no '} | best_mag={best_mag if found_valid else '-'}")
1934
-
1935
- if found_valid:
1936
- if cfg.verbose:
1937
- print(f" [REMOVABLE] {neuron_name} can be removed! Best remaining mag={best_mag}")
1938
- removed_neuron_names.append(neuron_name)
1939
- best_neurons_removed += 1
1940
- best_weights = vector_to_weights(best_config)
1941
- else:
1942
- if cfg.verbose:
1943
- print(f" [REQUIRED] {neuron_name} is necessary")
1944
-
1945
- if cfg.verbose:
1946
- elapsed = time.perf_counter() - start
1947
- print(f" [NEURON COMPLETE] {best_neurons_removed} neurons removable out of {len(neuron_names)}")
1948
- print(f" Removable: {removed_neuron_names if removed_neuron_names else 'none'}")
1949
- print(f" Time: {elapsed:.1f}s")
1950
-
1951
- return PruneResult(
1952
- method='neuron',
1953
- original_stats=original,
1954
- final_stats=circuit.stats(best_weights),
1955
- final_weights=best_weights,
1956
- fitness=evaluator.evaluate_single(best_weights),
1957
- time_seconds=time.perf_counter() - start,
1958
- metadata={'neurons_removed': best_neurons_removed, 'removed_names': removed_neuron_names}
1959
- )
1960
-
1961
-
1962
- def prune_lottery(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
1963
- """Lottery Ticket pruning."""
1964
- start = time.perf_counter()
1965
- original = circuit.stats()
1966
-
1967
- weights = circuit.clone_weights()
1968
- initial = circuit.clone_weights()
1969
- history = []
1970
-
1971
- if cfg.verbose:
1972
- print(f" Lottery: {cfg.lottery_rounds} rounds, {cfg.lottery_prune_rate * 100:.0f}% per round")
1973
-
1974
- mask = {k: torch.ones_like(v) for k, v in weights.items()}
1975
-
1976
- for rnd in range(cfg.lottery_rounds):
1977
- all_weights = []
1978
- for name, tensor in weights.items():
1979
- flat = tensor.flatten()
1980
- m_flat = mask[name].flatten()
1981
- for i in range(len(flat)):
1982
- if m_flat[i] > 0 and flat[i].item() != 0:
1983
- all_weights.append((abs(flat[i].item()), name, i))
1984
-
1985
- if not all_weights:
1986
- break
1987
-
1988
- all_weights.sort(key=lambda x: x[0])
1989
- n_prune = max(1, int(len(all_weights) * cfg.lottery_prune_rate))
1990
- to_prune = all_weights[:n_prune]
1991
-
1992
- for _, name, idx in to_prune:
1993
- m_flat = mask[name].flatten()
1994
- m_flat[idx] = 0
1995
- mask[name] = m_flat.view(mask[name].shape)
1996
-
1997
- for name in weights:
1998
- weights[name] = initial[name] * mask[name]
1999
-
2000
- fitness = evaluator.evaluate_single(weights)
2001
- stats = circuit.stats(weights)
2002
- history.append({'round': rnd, 'pruned': n_prune, 'fitness': fitness, 'magnitude': stats['magnitude']})
2003
-
2004
- if cfg.verbose:
2005
- print(f" Round {rnd}: pruned {n_prune}, fitness={fitness:.4f}")
2006
-
2007
- if fitness < cfg.fitness_threshold:
2008
- for _, name, idx in to_prune:
2009
- m_flat = mask[name].flatten()
2010
- m_flat[idx] = 1
2011
- mask[name] = m_flat.view(mask[name].shape)
2012
- for name in weights:
2013
- weights[name] = initial[name] * mask[name]
2014
- break
2015
-
2016
- return PruneResult(
2017
- method='lottery',
2018
- original_stats=original,
2019
- final_stats=circuit.stats(weights),
2020
- final_weights=weights,
2021
- fitness=evaluator.evaluate_single(weights),
2022
- time_seconds=time.perf_counter() - start,
2023
- history=history
2024
- )
2025
-
2026
-
2027
  def prune_topology(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2028
  """Topology search - remove connection groups."""
2029
  start = time.perf_counter()
@@ -2092,14 +1960,14 @@ def prune_topology(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: C
2092
  )
2093
 
2094
 
2095
- def prune_structured(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2096
- """Structured pruning - remove entire rows/columns of weight matrices."""
2097
  start = time.perf_counter()
2098
  weights = circuit.clone_weights()
2099
  original = circuit.stats(weights)
2100
 
2101
  if cfg.verbose:
2102
- print(f" Structured pruning (rows/columns)...")
2103
 
2104
  removed = 0
2105
 
@@ -2137,7 +2005,7 @@ def prune_structured(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
2137
  print(f" Removed {removed} rows/columns")
2138
 
2139
  return PruneResult(
2140
- method='structured',
2141
  original_stats=original,
2142
  final_stats=circuit.stats(weights),
2143
  final_weights=weights,
@@ -2279,99 +2147,6 @@ def prune_weight_sharing(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator,
2279
  )
2280
 
2281
 
2282
- def prune_random(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2283
- """Random search baseline."""
2284
- start = time.perf_counter()
2285
- original = circuit.stats()
2286
-
2287
- if cfg.verbose:
2288
- print(f" Random search ({cfg.random_iterations:,} iterations)...")
2289
-
2290
- base_vector = circuit.weights_to_vector(circuit.weights)
2291
- best_weights = circuit.clone_weights()
2292
- best_mag = sum(t.abs().sum().item() for t in best_weights.values())
2293
- best_fitness = evaluator.evaluate_single(best_weights)
2294
-
2295
- n_valid = 0
2296
-
2297
- batch_size = min(10000, evaluator.max_batch)
2298
- n_batches = cfg.random_iterations // batch_size
2299
-
2300
- for batch in range(n_batches):
2301
- population = base_vector.unsqueeze(0).expand(batch_size, -1).clone()
2302
- noise = torch.randn_like(population) * 2
2303
- population = (population + noise).round()
2304
-
2305
- fitness = evaluator.evaluate_population(population)
2306
-
2307
- valid_mask = fitness >= cfg.fitness_threshold
2308
- n_valid += valid_mask.sum().item()
2309
-
2310
- if valid_mask.any():
2311
- magnitudes = population.abs().sum(dim=1)
2312
- magnitudes[~valid_mask] = float('inf')
2313
- best_idx = magnitudes.argmin().item()
2314
-
2315
- if magnitudes[best_idx] < best_mag:
2316
- best_mag = magnitudes[best_idx].item()
2317
- best_weights = circuit.vector_to_weights(population[best_idx])
2318
- best_fitness = fitness[best_idx].item()
2319
-
2320
- if cfg.verbose and batch % 10 == 0:
2321
- print(f" Batch {batch}/{n_batches}: valid={n_valid}, best_mag={best_mag:.0f}")
2322
-
2323
- return PruneResult(
2324
- method='random',
2325
- original_stats=original,
2326
- final_stats=circuit.stats(best_weights),
2327
- final_weights=best_weights,
2328
- fitness=best_fitness,
2329
- time_seconds=time.perf_counter() - start,
2330
- metadata={'total_valid': n_valid, 'total_tested': cfg.random_iterations}
2331
- )
2332
-
2333
-
2334
- def prune_pareto(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2335
- """Explore Pareto frontier of correctness vs. size."""
2336
- start = time.perf_counter()
2337
- original = circuit.stats()
2338
- frontier = []
2339
-
2340
- if cfg.verbose:
2341
- print(f" Exploring Pareto frontier...")
2342
-
2343
- for target in cfg.pareto_levels:
2344
- relaxed_cfg = Config(
2345
- device=cfg.device,
2346
- fitness_threshold=target,
2347
- magnitude_passes=30,
2348
- verbose=False,
2349
- vram=cfg.vram
2350
- )
2351
-
2352
- result = prune_magnitude(circuit, evaluator, relaxed_cfg)
2353
-
2354
- frontier.append({
2355
- 'target': target,
2356
- 'actual': result.fitness,
2357
- 'magnitude': result.final_stats['magnitude'],
2358
- 'nonzero': result.final_stats['nonzero']
2359
- })
2360
-
2361
- if cfg.verbose:
2362
- print(f" Target {target:.2f}: fitness={result.fitness:.4f}, mag={result.final_stats['magnitude']:.0f}")
2363
-
2364
- return PruneResult(
2365
- method='pareto',
2366
- original_stats=original,
2367
- final_stats=frontier[-1] if frontier else original,
2368
- final_weights=circuit.clone_weights(),
2369
- fitness=frontier[0]['actual'] if frontier else 1.0,
2370
- time_seconds=time.perf_counter() - start,
2371
- history=frontier
2372
- )
2373
-
2374
-
2375
  def prune_depth(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2376
  """
2377
  Depth reduction - attempt to collapse consecutive layers.
@@ -2790,9 +2565,9 @@ def _configs_at_magnitude(mag: int, n_params: int):
2790
  yield tuple(signed)
2791
 
2792
 
2793
- def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2794
  """
2795
- Exhaustive search by magnitude level - finds provably optimal solutions.
2796
 
2797
  Searches magnitude 0, then 1, then 2, ... until valid solutions found.
2798
  Returns ALL valid solutions at the minimum magnitude (to discover families).
@@ -2924,7 +2699,7 @@ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
2924
  print(f" - Time: {elapsed:.1f}s")
2925
 
2926
  return PruneResult(
2927
- method='exhaustive',
2928
  original_stats=original,
2929
  final_stats=circuit.stats(best_weights),
2930
  final_weights=best_weights,
@@ -2939,71 +2714,194 @@ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
2939
  )
2940
 
2941
 
2942
- def export_coq(circuit: AdaptiveCircuit, weights: Dict[str, torch.Tensor], cfg: Config) -> Path:
2943
  """
2944
- Export threshold circuit to Coq for formal verification.
2945
- Generates a Coq file with the circuit definition and correctness theorem.
2946
  """
2947
- cfg.coq_output_dir.mkdir(parents=True, exist_ok=True)
2948
-
2949
- name = circuit.spec.name.replace('-', '_')
2950
- output_path = cfg.coq_output_dir / f'{name}.v'
2951
-
2952
- lines = [
2953
- f'(* Threshold circuit: {circuit.spec.name} *)',
2954
- f'(* Auto-generated by threshold-pruner *)',
2955
- f'(* Inputs: {circuit.spec.inputs}, Outputs: {circuit.spec.outputs} *)',
2956
- '',
2957
- 'Require Import ZArith.',
2958
- 'Require Import List.',
2959
- 'Import ListNotations.',
2960
- '',
2961
- '(* Threshold gate: output 1 if weighted sum >= 0 *)',
2962
- 'Definition threshold (weights : list Z) (bias : Z) (inputs : list Z) : Z :=',
2963
- ' let sum := fold_left (fun acc wb => acc + fst wb * snd wb) ',
2964
- ' (combine weights inputs) bias in',
2965
- ' if Z.geb sum 0 then 1 else 0.',
2966
- '',
2967
- f'(* Circuit definition for {name} *)',
2968
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2969
 
2970
- neuron_defs = []
2971
- for neuron_name in circuit.graph.neuron_order:
2972
- info = circuit.graph.neurons[neuron_name]
2973
- w_key = info.get('weight_key')
2974
- b_key = info.get('bias_key')
2975
 
2976
- if w_key and w_key in weights:
2977
- w = weights[w_key].flatten().tolist()
2978
- b = weights[b_key].item() if b_key and b_key in weights else 0
 
 
 
 
2979
 
2980
- w_str = '[' + '; '.join(str(int(x)) for x in w) + ']'
2981
- safe_name = neuron_name.replace('.', '_').replace('-', '_')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2982
 
2983
- neuron_defs.append(f'Definition {safe_name}_weights : list Z := {w_str}.')
2984
- neuron_defs.append(f'Definition {safe_name}_bias : Z := {int(b)}.')
2985
- neuron_defs.append('')
 
2986
 
2987
- lines.extend(neuron_defs)
2988
 
2989
- lines.extend([
2990
- '',
2991
- f'(* Correctness theorem placeholder *)',
2992
- f'(* To be completed with specific input/output verification *)',
2993
- f'Theorem {name}_correct : forall inputs,',
2994
- f' length inputs = {circuit.spec.inputs}%nat ->',
2995
- f' (* output matches expected *)',
2996
- f' True.',
2997
- 'Proof.',
2998
- ' intros.',
2999
- ' (* Proof to be completed *)',
3000
- 'Admitted.',
3001
- ])
3002
 
3003
- with open(output_path, 'w', encoding='utf-8') as f:
3004
- f.write('\n'.join(lines))
 
 
 
 
 
 
3005
 
3006
- return output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
3007
 
3008
 
3009
  def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneResult]:
@@ -3040,21 +2938,18 @@ def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneRes
3040
  ('magnitude', cfg.run_magnitude, lambda: prune_magnitude(circuit, evaluator, cfg)),
3041
  ('zero', cfg.run_zero, lambda: prune_zero(circuit, evaluator, cfg)),
3042
  ('quantize', cfg.run_quantize, lambda: prune_quantize(circuit, evaluator, cfg)),
3043
- ('neuron', cfg.run_neuron, lambda: prune_neuron(circuit, evaluator, cfg)),
3044
- ('lottery', cfg.run_lottery, lambda: prune_lottery(circuit, evaluator, cfg)),
3045
  ('topology', cfg.run_topology, lambda: prune_topology(circuit, evaluator, cfg)),
3046
- ('structured', cfg.run_structured, lambda: prune_structured(circuit, evaluator, cfg)),
3047
  ('sensitivity', cfg.run_sensitivity, lambda: prune_sensitivity(circuit, evaluator, cfg)),
3048
  ('weight_sharing', cfg.run_weight_sharing, lambda: prune_weight_sharing(circuit, evaluator, cfg)),
3049
  ('depth', cfg.run_depth, lambda: prune_depth(circuit, evaluator, cfg)),
3050
  ('gate_subst', cfg.run_gate_subst, lambda: prune_gate_substitution(circuit, evaluator, cfg)),
3051
  ('symmetry', cfg.run_symmetry, lambda: prune_symmetry(circuit, evaluator, cfg)),
3052
  ('fanin', cfg.run_fanin, lambda: prune_fanin(circuit, evaluator, cfg)),
3053
- ('exhaustive', cfg.run_exhaustive, lambda: prune_exhaustive(circuit, evaluator, cfg)),
3054
- ('random', cfg.run_random, lambda: prune_random(circuit, evaluator, cfg)),
3055
  ('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
3056
  ('annealing', cfg.run_annealing, lambda: prune_annealing(circuit, evaluator, cfg)),
3057
- ('pareto', cfg.run_pareto, lambda: prune_pareto(circuit, evaluator, cfg)),
3058
  ]
3059
 
3060
  enabled_methods = [(name, fn) for name, enabled, fn in methods if enabled]
@@ -3100,10 +2995,6 @@ def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneRes
3100
  reduction = 1 - best_mag / original['magnitude']
3101
  print(f"\n BEST: {best_method} ({reduction * 100:.1f}% magnitude reduction)")
3102
 
3103
- if cfg.export_coq:
3104
- coq_path = export_coq(circuit, results[best_method].final_weights, cfg)
3105
- print(f" Coq export: {coq_path}")
3106
-
3107
  return results
3108
 
3109
 
@@ -3228,8 +3119,9 @@ def main():
3228
  parser.add_argument('--sa-iters', type=int, default=50000, help='Simulated annealing iterations')
3229
  parser.add_argument('--sa-chains', type=int, default=0, help='Parallel SA chains (0=auto)')
3230
  parser.add_argument('--vram-target', type=float, default=0.75)
3231
- parser.add_argument('--export-coq', action='store_true')
3232
  parser.add_argument('--fanin-target', type=int, default=4)
 
 
3233
 
3234
  args = parser.parse_args()
3235
 
@@ -3251,15 +3143,16 @@ def main():
3251
  evo_generations=args.evo_gens,
3252
  annealing_iterations=args.sa_iters,
3253
  annealing_parallel_chains=args.sa_chains,
3254
- export_coq=args.export_coq,
3255
- fanin_target=args.fanin_target
 
3256
  )
3257
 
3258
  if args.methods:
3259
  all_methods = ['magnitude', 'zero', 'quantize', 'evolutionary', 'annealing',
3260
- 'neuron', 'lottery', 'topology', 'structured', 'sensitivity',
3261
- 'weight_sharing', 'random', 'pareto', 'depth', 'gate_subst',
3262
- 'symmetry', 'fanin', 'exhaustive']
3263
  for m in all_methods:
3264
  setattr(cfg, f'run_{m}', False)
3265
 
@@ -3271,19 +3164,16 @@ def main():
3271
  'quant': 'quantize', 'quantize': 'quantize',
3272
  'evo': 'evolutionary', 'evolutionary': 'evolutionary',
3273
  'anneal': 'annealing', 'annealing': 'annealing', 'sa': 'annealing',
3274
- 'neuron': 'neuron',
3275
- 'lottery': 'lottery',
3276
  'topo': 'topology', 'topology': 'topology',
3277
- 'struct': 'structured', 'structured': 'structured',
3278
  'sens': 'sensitivity', 'sensitivity': 'sensitivity',
3279
  'share': 'weight_sharing', 'weight_sharing': 'weight_sharing', 'sharing': 'weight_sharing',
3280
- 'rand': 'random', 'random': 'random',
3281
- 'pareto': 'pareto',
3282
  'depth': 'depth',
3283
  'gate': 'gate_subst', 'gate_subst': 'gate_subst', 'subst': 'gate_subst',
3284
  'sym': 'symmetry', 'symmetry': 'symmetry',
3285
  'fanin': 'fanin', 'fan': 'fanin',
3286
- 'exhaustive': 'exhaustive', 'exh': 'exhaustive', 'brute': 'exhaustive'
 
3287
  }
3288
  if m in method_map:
3289
  setattr(cfg, f'run_{method_map[m]}', True)
@@ -3343,6 +3233,11 @@ def main():
3343
  print(" python prune.py threshold-hamming74decoder --methods evo")
3344
  print(" python prune.py threshold-xor --methods evo --evo-pop 500000 --evo-gens 5000")
3345
  print("")
 
 
 
 
 
3346
  print(" # Pipeline mode (chained, each stage feeds into next):")
3347
  print(" python prune.py threshold-hamming74decoder --pipeline evo,mag,zero,quant --save")
3348
  print(" python prune.py threshold-xor --pipeline anneal,mag,zero --sa-iters 100000")
 
2
  Threshold Circuit Pruner
3
 
4
  Comprehensive pruning framework for threshold logic circuits.
5
+ Supports 15 pruning methods with GPU-optimized parallel evaluation.
6
 
7
  Usage:
8
  python prune.py threshold-hamming74decoder
9
  python prune.py threshold-hamming74decoder --methods evo,depth,symmetry
10
  python prune.py --list
11
  python prune.py --all --max-inputs 8
 
12
  """
13
 
14
  import torch
 
125
  run_quantize: bool = True
126
  run_evolutionary: bool = True
127
  run_annealing: bool = True
128
+ run_structural: bool = True
 
129
  run_topology: bool = True
 
130
  run_sensitivity: bool = True
131
  run_weight_sharing: bool = True
 
 
132
  run_depth: bool = True
133
  run_gate_subst: bool = True
134
  run_symmetry: bool = True
135
  run_fanin: bool = True
136
+ run_exhaustive_mag: bool = True
137
+ run_exhaustive_sparse: bool = True
138
 
139
  magnitude_passes: int = 100
140
 
141
  exhaustive_range: int = 2
142
  exhaustive_max_params: int = 12
143
+ sparse_max_weight: int = 3
144
 
145
  evo_generations: int = 2000
146
  evo_pop_size: int = 0
 
158
  annealing_parallel_chains: int = 0
159
 
160
  quantize_targets: List[float] = field(default_factory=lambda: [-1.0, 0.0, 1.0])
 
 
 
 
161
 
162
  topology_generations: int = 500
163
  topology_remove_prob: float = 0.2
 
165
 
166
  sensitivity_samples: int = 1000
167
 
 
 
168
  depth_max_collapse: int = 3
169
  fanin_target: int = 4
170
 
 
 
 
171
 
172
  @dataclass
173
  class CircuitSpec:
 
248
  }
249
  self.layer_groups[depth].append(neuron_name)
250
 
251
+ self._infer_depth_from_shapes()
252
  self._identify_outputs()
253
 
254
  def _estimate_depth(self, name: str) -> int:
 
274
 
275
  return depth
276
 
277
+ def _infer_depth_from_shapes(self):
278
+ """Second pass: infer depth from weight shapes when naming is ambiguous."""
279
+ neurons_at_depth_0 = []
280
+ neurons_needing_inference = []
281
+
282
+ for name, info in self.neurons.items():
283
+ input_size = info.get('input_size', 0)
284
+ if input_size == self.n_inputs:
285
+ info['depth'] = 0
286
+ info['input_source'] = 'raw'
287
+ info['input_neurons'] = []
288
+ neurons_at_depth_0.append(name)
289
+ elif input_size > 0 and input_size != self.n_inputs:
290
+ neurons_needing_inference.append((name, input_size))
291
+
292
+ for name, input_size in neurons_needing_inference:
293
+ if input_size == len(neurons_at_depth_0):
294
+ self.neurons[name]['depth'] = 1
295
+ self.neurons[name]['input_source'] = 'neurons'
296
+ self.neurons[name]['input_neurons'] = sorted(neurons_at_depth_0)
297
+ elif input_size < len(neurons_at_depth_0) and input_size > 0:
298
+ candidates = sorted(neurons_at_depth_0)[:input_size]
299
+ self.neurons[name]['depth'] = 1
300
+ self.neurons[name]['input_source'] = 'neurons'
301
+ self.neurons[name]['input_neurons'] = candidates
302
+
303
+ self.layer_groups = defaultdict(list)
304
+ for name, info in self.neurons.items():
305
+ self.layer_groups[info['depth']].append(name)
306
+
307
  def _identify_outputs(self):
308
  """Identify which neurons are outputs based on n_outputs and topology."""
309
 
 
417
  return outputs
418
 
419
  def _get_neuron_input(self, neuron_name: str, activations: Dict, raw_input: torch.Tensor, expected_size: int) -> torch.Tensor:
420
+ """Determine input for a neuron based on topology inference or naming conventions."""
421
+ info = self.neurons.get(neuron_name, {})
422
+
423
+ if info.get('input_source') == 'raw' or expected_size == self.n_inputs or expected_size == raw_input.shape[-1]:
424
  return raw_input
425
 
426
+ if info.get('input_source') == 'neurons' and info.get('input_neurons'):
427
+ input_neurons = info['input_neurons']
428
+ vals = [activations[n] for n in input_neurons if n in activations]
429
+ if len(vals) == len(input_neurons):
430
+ return torch.stack(vals, dim=-1)
431
 
432
  if 'layer2' in neuron_name:
433
  base = neuron_name.replace('.layer2', '')
 
1083
  if original_fitness < 0.999 and self.cfg.verbose:
1084
  print(f" [EVAL ERROR] Native eval fitness={original_fitness:.4f}")
1085
 
1086
+ if self.batched_ready and not self.use_native_eval:
1087
+ try:
1088
+ test_vecs = []
1089
+ base = self.circuit.base_vector.clone()
1090
+ test_vecs.append(base)
1091
+ for _ in range(3):
1092
+ perturbed = base.clone()
1093
+ mask = torch.rand_like(perturbed) < 0.3
1094
+ perturbed[mask] = torch.randint(-2, 3, (mask.sum().item(),), device=self.device, dtype=torch.float32)
1095
+ test_vecs.append(perturbed)
1096
+
1097
+ test_pop = torch.stack(test_vecs)
1098
+ single_results = torch.tensor([
1099
+ self.evaluate_single(self.circuit.vector_to_weights(v)) for v in test_vecs
1100
+ ], device=self.device)
1101
+
1102
+ batched_results = self._evaluate_batched(test_pop)
1103
+
1104
+ if not torch.allclose(single_results, batched_results, atol=0.01):
1105
+ if self.cfg.verbose:
1106
+ print(f" [EVAL WARNING] Batched eval mismatch (single={single_results.tolist()}, batched={batched_results.tolist()})")
1107
+ print(f" [EVAL WARNING] Falling back to native eval")
1108
+ self.use_native_eval = self.circuit.has_native
1109
+ self.batched_ready = False
1110
+ except Exception as e:
1111
+ if self.cfg.verbose:
1112
+ print(f" [EVAL WARNING] Batched eval failed ({e}), falling back to native eval")
1113
+ self.use_native_eval = self.circuit.has_native
1114
+ self.batched_ready = False
1115
+
1116
  def _setup_vmap(self):
1117
  """Setup vmap-based parallel evaluation."""
1118
  try:
 
1178
 
1179
  input_size = w_shape[-1] if w_shape else 0
1180
 
1181
+ input_source = info.get('input_source', 'raw')
1182
+ input_neurons = info.get('input_neurons', [])
1183
+
1184
+ if input_source == 'raw' and input_size != self.n_inputs:
1185
+ if input_size == self.n_inputs:
1186
+ input_source = 'raw'
1187
+ elif 'layer2' in neuron_name:
1188
+ base = neuron_name.replace('.layer2', '')
1189
+ or_key = f'{base}.layer1.or'
1190
+ nand_key = f'{base}.layer1.nand'
1191
+ if or_key in graph.neurons and nand_key in graph.neurons:
1192
+ input_source = 'neurons'
1193
+ input_neurons = [or_key, nand_key]
1194
+ elif 'xor_final' in neuron_name:
1195
+ prefix = neuron_name.split('.xor_final')[0]
1196
+ candidates = [n for n in graph.neurons if n.startswith(prefix) and 'xor_' in n and 'final' not in n and 'layer2' in n]
1197
+ if len(candidates) >= 2:
1198
+ input_source = 'neurons'
1199
+ input_neurons = sorted(candidates)[-2:]
1200
 
1201
  self.neuron_eval_order.append(neuron_name)
1202
  self.neuron_weight_slices[neuron_name] = {
 
1892
  )
1893
 
1894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1895
  def prune_topology(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
1896
  """Topology search - remove connection groups."""
1897
  start = time.perf_counter()
 
1960
  )
1961
 
1962
 
1963
+ def prune_structural(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
1964
+ """Structural pruning - remove entire rows/columns of weight matrices."""
1965
  start = time.perf_counter()
1966
  weights = circuit.clone_weights()
1967
  original = circuit.stats(weights)
1968
 
1969
  if cfg.verbose:
1970
+ print(f" Structural pruning (rows/columns)...")
1971
 
1972
  removed = 0
1973
 
 
2005
  print(f" Removed {removed} rows/columns")
2006
 
2007
  return PruneResult(
2008
+ method='structural',
2009
  original_stats=original,
2010
  final_stats=circuit.stats(weights),
2011
  final_weights=weights,
 
2147
  )
2148
 
2149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2150
  def prune_depth(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2151
  """
2152
  Depth reduction - attempt to collapse consecutive layers.
 
2565
  yield tuple(signed)
2566
 
2567
 
2568
+ def prune_exhaustive_mag(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2569
  """
2570
+ Exhaustive search by magnitude level - finds provably minimum-magnitude solutions.
2571
 
2572
  Searches magnitude 0, then 1, then 2, ... until valid solutions found.
2573
  Returns ALL valid solutions at the minimum magnitude (to discover families).
 
2699
  print(f" - Time: {elapsed:.1f}s")
2700
 
2701
  return PruneResult(
2702
+ method='exhaustive_mag',
2703
  original_stats=original,
2704
  final_stats=circuit.stats(best_weights),
2705
  final_weights=best_weights,
 
2714
  )
2715
 
2716
 
2717
+ def _configs_with_k_nonzeros(k: int, n_params: int, max_weight: int):
2718
  """
2719
+ Generate all n_params-length configs with exactly k nonzero values.
2720
+ Nonzero values range from -max_weight to +max_weight (excluding 0).
2721
  """
2722
+ if k > n_params or k < 0:
2723
+ return
2724
+
2725
+ nonzero_vals = list(range(-max_weight, 0)) + list(range(1, max_weight + 1))
2726
+
2727
+ for positions in combinations(range(n_params), k):
2728
+ position_set = set(positions)
2729
+ for vals in product(nonzero_vals, repeat=k):
2730
+ config = [0] * n_params
2731
+ for pos, val in zip(positions, vals):
2732
+ config[pos] = val
2733
+ yield tuple(config)
2734
+
2735
+
2736
+ def prune_exhaustive_sparse(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
2737
+ """
2738
+ Exhaustive search by sparsity level - finds provably maximum-sparsity solutions.
2739
+
2740
+ Searches from 1 nonzero, then 2, then 3, ... until valid solutions found.
2741
+ Returns ALL valid solutions at the minimum nonzero count (to discover families).
2742
+
2743
+ Useful for hardware where connection count matters more than weight magnitude.
2744
+ """
2745
+ start = time.perf_counter()
2746
+ original = circuit.stats()
2747
+
2748
+ n_params = original['total']
2749
+ original_nonzeros = original['nonzero']
2750
+ max_weight = cfg.sparse_max_weight
2751
+
2752
+ if n_params > cfg.exhaustive_max_params:
2753
+ if cfg.verbose:
2754
+ print(f" [SPARSE] Skipping: {n_params} params exceeds max {cfg.exhaustive_max_params}")
2755
+ return PruneResult(
2756
+ method='exhaustive_sparse',
2757
+ original_stats=original,
2758
+ final_stats=original,
2759
+ final_weights=circuit.clone_weights(),
2760
+ fitness=evaluator.evaluate_single(circuit.weights),
2761
+ time_seconds=time.perf_counter() - start,
2762
+ metadata={'skipped': True, 'reason': 'too_many_params'}
2763
+ )
2764
+
2765
+ if cfg.verbose:
2766
+ print(f" [SPARSE] Parameters: {n_params}")
2767
+ print(f" [SPARSE] Original nonzeros: {original_nonzeros}")
2768
+ print(f" [SPARSE] Max weight magnitude: {max_weight}")
2769
+ print(f" [SPARSE] Searching by nonzero count (1, 2, 3, ...)")
2770
+
2771
+ weight_keys = list(circuit.weights.keys())
2772
+ weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
2773
+ weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}
2774
+
2775
+ def vector_to_weights(vec):
2776
+ weights = {}
2777
+ idx = 0
2778
+ for k in weight_keys:
2779
+ size = weight_sizes[k]
2780
+ weights[k] = torch.tensor(vec[idx:idx+size], dtype=torch.float32, device=cfg.device).view(weight_shapes[k])
2781
+ idx += size
2782
+ return weights
2783
+
2784
+ total_tested = 0
2785
+ all_solutions = []
2786
+ optimal_nonzeros = None
2787
+
2788
+ for n_nonzero in range(1, n_params + 1):
2789
+ nz_start = time.perf_counter()
2790
+
2791
+ n_positions = math.comb(n_params, n_nonzero)
2792
+ n_value_combos = (2 * max_weight) ** n_nonzero
2793
+ n_configs = n_positions * n_value_combos
2794
+
2795
+ if cfg.verbose:
2796
+ print(f" Nonzeros {n_nonzero}: {n_configs:,} configurations...", end=" ", flush=True)
2797
+
2798
+ valid_at_nz = []
2799
+
2800
+ batch_configs = []
2801
+ batch_size = min(100000, n_configs)
2802
+
2803
+ for config in _configs_with_k_nonzeros(n_nonzero, n_params, max_weight):
2804
+ batch_configs.append(config)
2805
 
2806
+ if len(batch_configs) >= batch_size:
2807
+ population = torch.tensor(batch_configs, dtype=torch.float32, device=cfg.device)
 
 
 
2808
 
2809
+ try:
2810
+ fitness_batch = evaluator.evaluate_population(population)
2811
+ except:
2812
+ fitness_batch = torch.tensor([
2813
+ evaluator.evaluate_single(vector_to_weights(c))
2814
+ for c in batch_configs
2815
+ ], device=cfg.device)
2816
 
2817
+ valid_mask = fitness_batch >= cfg.fitness_threshold
2818
+ for i, is_valid in enumerate(valid_mask.tolist()):
2819
+ if is_valid:
2820
+ valid_at_nz.append(batch_configs[i])
2821
+
2822
+ total_tested += len(batch_configs)
2823
+ batch_configs = []
2824
+
2825
+ if batch_configs:
2826
+ population = torch.tensor(batch_configs, dtype=torch.float32, device=cfg.device)
2827
+
2828
+ try:
2829
+ fitness_batch = evaluator.evaluate_population(population)
2830
+ except:
2831
+ fitness_batch = torch.tensor([
2832
+ evaluator.evaluate_single(vector_to_weights(c))
2833
+ for c in batch_configs
2834
+ ], device=cfg.device)
2835
+
2836
+ valid_mask = fitness_batch >= cfg.fitness_threshold
2837
+ for i, is_valid in enumerate(valid_mask.tolist()):
2838
+ if is_valid:
2839
+ valid_at_nz.append(batch_configs[i])
2840
+
2841
+ total_tested += len(batch_configs)
2842
+
2843
+ nz_time = time.perf_counter() - nz_start
2844
+
2845
+ if valid_at_nz:
2846
+ if cfg.verbose:
2847
+ print(f"FOUND {len(valid_at_nz)} solutions! ({nz_time:.2f}s)")
2848
+
2849
+ optimal_nonzeros = n_nonzero
2850
+ all_solutions = valid_at_nz
2851
+
2852
+ if cfg.verbose:
2853
+ print(f" [SPARSE] Optimal nonzeros: {optimal_nonzeros}")
2854
+ print(f" [SPARSE] Solutions found: {len(all_solutions)}")
2855
+ print(f" [SPARSE] Solution analysis:")
2856
+ print(f" {'#':<3} {'NZ':<4} {'Mag':<6} {'Max|w|':<7} {'Weights'}")
2857
+ print(f" {'-'*60}")
2858
+ for i, sol in enumerate(all_solutions[:20]):
2859
+ nz = sum(1 for v in sol if v != 0)
2860
+ mag = sum(abs(v) for v in sol)
2861
+ max_w = max(abs(v) for v in sol) if any(v != 0 for v in sol) else 0
2862
+ print(f" {i+1:<3} {nz:<4} {mag:<6} {max_w:<7} {sol}")
2863
+ if len(all_solutions) > 20:
2864
+ print(f" ... and {len(all_solutions) - 20} more")
2865
 
2866
+ break
2867
+ else:
2868
+ if cfg.verbose:
2869
+ print(f"none ({nz_time:.2f}s)")
2870
 
2871
+ elapsed = time.perf_counter() - start
2872
 
2873
+ if all_solutions:
2874
+ best_combo = min(all_solutions, key=lambda x: sum(abs(v) for v in x))
2875
+ best_weights = vector_to_weights(best_combo)
2876
+ best_fitness = evaluator.evaluate_single(best_weights)
2877
+ else:
2878
+ best_weights = circuit.clone_weights()
2879
+ best_fitness = evaluator.evaluate_single(best_weights)
2880
+ optimal_nonzeros = original_nonzeros
 
 
 
 
 
2881
 
2882
+ if cfg.verbose:
2883
+ final_stats = circuit.stats(best_weights)
2884
+ print(f" [SPARSE COMPLETE]")
2885
+ print(f" - Configurations tested: {total_tested:,}")
2886
+ print(f" - Optimal nonzeros: {optimal_nonzeros} (original: {original_nonzeros})")
2887
+ print(f" - Total solutions at optimal: {len(all_solutions)}")
2888
+ print(f" - Sparsity: {(1 - optimal_nonzeros/n_params)*100:.1f}%")
2889
+ print(f" - Time: {elapsed:.1f}s")
2890
 
2891
+ return PruneResult(
2892
+ method='exhaustive_sparse',
2893
+ original_stats=original,
2894
+ final_stats=circuit.stats(best_weights),
2895
+ final_weights=best_weights,
2896
+ fitness=best_fitness,
2897
+ time_seconds=elapsed,
2898
+ metadata={
2899
+ 'optimal_nonzeros': optimal_nonzeros,
2900
+ 'total_tested': total_tested,
2901
+ 'solutions_count': len(all_solutions),
2902
+ 'all_solutions': all_solutions[:100]
2903
+ }
2904
+ )
2905
 
2906
 
2907
  def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneResult]:
 
2938
  ('magnitude', cfg.run_magnitude, lambda: prune_magnitude(circuit, evaluator, cfg)),
2939
  ('zero', cfg.run_zero, lambda: prune_zero(circuit, evaluator, cfg)),
2940
  ('quantize', cfg.run_quantize, lambda: prune_quantize(circuit, evaluator, cfg)),
2941
+ ('structural', cfg.run_structural, lambda: prune_structural(circuit, evaluator, cfg)),
 
2942
  ('topology', cfg.run_topology, lambda: prune_topology(circuit, evaluator, cfg)),
 
2943
  ('sensitivity', cfg.run_sensitivity, lambda: prune_sensitivity(circuit, evaluator, cfg)),
2944
  ('weight_sharing', cfg.run_weight_sharing, lambda: prune_weight_sharing(circuit, evaluator, cfg)),
2945
  ('depth', cfg.run_depth, lambda: prune_depth(circuit, evaluator, cfg)),
2946
  ('gate_subst', cfg.run_gate_subst, lambda: prune_gate_substitution(circuit, evaluator, cfg)),
2947
  ('symmetry', cfg.run_symmetry, lambda: prune_symmetry(circuit, evaluator, cfg)),
2948
  ('fanin', cfg.run_fanin, lambda: prune_fanin(circuit, evaluator, cfg)),
2949
+ ('exhaustive_mag', cfg.run_exhaustive_mag, lambda: prune_exhaustive_mag(circuit, evaluator, cfg)),
2950
+ ('exhaustive_sparse', cfg.run_exhaustive_sparse, lambda: prune_exhaustive_sparse(circuit, evaluator, cfg)),
2951
  ('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
2952
  ('annealing', cfg.run_annealing, lambda: prune_annealing(circuit, evaluator, cfg)),
 
2953
  ]
2954
 
2955
  enabled_methods = [(name, fn) for name, enabled, fn in methods if enabled]
 
2995
  reduction = 1 - best_mag / original['magnitude']
2996
  print(f"\n BEST: {best_method} ({reduction * 100:.1f}% magnitude reduction)")
2997
 
 
 
 
 
2998
  return results
2999
 
3000
 
 
3119
  parser.add_argument('--sa-iters', type=int, default=50000, help='Simulated annealing iterations')
3120
  parser.add_argument('--sa-chains', type=int, default=0, help='Parallel SA chains (0=auto)')
3121
  parser.add_argument('--vram-target', type=float, default=0.75)
 
3122
  parser.add_argument('--fanin-target', type=int, default=4)
3123
+ parser.add_argument('--sparse-max-weight', type=int, default=3, help='Max weight magnitude for sparse search')
3124
+ parser.add_argument('--exhaustive-max-params', type=int, default=12, help='Max params for exhaustive search')
3125
 
3126
  args = parser.parse_args()
3127
 
 
3143
  evo_generations=args.evo_gens,
3144
  annealing_iterations=args.sa_iters,
3145
  annealing_parallel_chains=args.sa_chains,
3146
+ fanin_target=args.fanin_target,
3147
+ sparse_max_weight=args.sparse_max_weight,
3148
+ exhaustive_max_params=args.exhaustive_max_params
3149
  )
3150
 
3151
  if args.methods:
3152
  all_methods = ['magnitude', 'zero', 'quantize', 'evolutionary', 'annealing',
3153
+ 'structural', 'topology', 'sensitivity', 'weight_sharing',
3154
+ 'depth', 'gate_subst', 'symmetry', 'fanin',
3155
+ 'exhaustive_mag', 'exhaustive_sparse']
3156
  for m in all_methods:
3157
  setattr(cfg, f'run_{m}', False)
3158
 
 
3164
  'quant': 'quantize', 'quantize': 'quantize',
3165
  'evo': 'evolutionary', 'evolutionary': 'evolutionary',
3166
  'anneal': 'annealing', 'annealing': 'annealing', 'sa': 'annealing',
3167
+ 'structural': 'structural', 'struct': 'structural',
 
3168
  'topo': 'topology', 'topology': 'topology',
 
3169
  'sens': 'sensitivity', 'sensitivity': 'sensitivity',
3170
  'share': 'weight_sharing', 'weight_sharing': 'weight_sharing', 'sharing': 'weight_sharing',
 
 
3171
  'depth': 'depth',
3172
  'gate': 'gate_subst', 'gate_subst': 'gate_subst', 'subst': 'gate_subst',
3173
  'sym': 'symmetry', 'symmetry': 'symmetry',
3174
  'fanin': 'fanin', 'fan': 'fanin',
3175
+ 'exhaustive_mag': 'exhaustive_mag', 'exh_mag': 'exhaustive_mag', 'exh': 'exhaustive_mag', 'brute': 'exhaustive_mag',
3176
+ 'exhaustive_sparse': 'exhaustive_sparse', 'exh_sparse': 'exhaustive_sparse', 'sparse': 'exhaustive_sparse'
3177
  }
3178
  if m in method_map:
3179
  setattr(cfg, f'run_{method_map[m]}', True)
 
3233
  print(" python prune.py threshold-hamming74decoder --methods evo")
3234
  print(" python prune.py threshold-xor --methods evo --evo-pop 500000 --evo-gens 5000")
3235
  print("")
3236
+ print(" # Exhaustive search (provably optimal):")
3237
+ print(" python prune.py threshold-xor --methods exh_mag # minimize magnitude")
3238
+ print(" python prune.py threshold-xor --methods exh_sparse # minimize nonzeros")
3239
+ print(" python prune.py threshold-mux --methods sparse --sparse-max-weight 2")
3240
+ print("")
3241
  print(" # Pipeline mode (chained, each stage feeds into next):")
3242
  print(" python prune.py threshold-hamming74decoder --pipeline evo,mag,zero,quant --save")
3243
  print(" python prune.py threshold-xor --pipeline anneal,mag,zero --sa-iters 100000")