CharlesCNorton committed on
Commit e69d4eb · 0 Parent(s)

Initial commit: threshold circuit pruning framework


7 pruning methods: magnitude, batched, zero, quantize, evolutionary, annealing, pareto

Files changed (2)
  1. README.md +59 -0
  2. prune.py +1162 -0
README.md ADDED
@@ -0,0 +1,59 @@
+ # Threshold Pruner
+
+ Multi-method pruning framework for threshold logic circuits.
+
+ ## Methods
+
+ | Method | Flag | Description |
+ |--------|------|-------------|
+ | Magnitude Reduction | `mag` | Reduce weights by 1 toward zero |
+ | Batched Magnitude | `batched` | GPU-parallel magnitude reduction |
+ | Zero Pruning | `zero` | Set weights directly to 0 |
+ | Quantization | `quant` | Force weights to {-1, 0, 1} |
+ | Evolutionary | `evo` | Mutation + selection with parsimony |
+ | Simulated Annealing | `anneal` | Gradual cooling search |
+ | Pareto Search | `pareto` | Correctness vs size tradeoff |
+
+ ## Usage
+
+ ```bash
+ # List available circuits
+ python prune.py --list
+
+ # Prune a circuit with all methods
+ python prune.py threshold-hamming74decoder
+
+ # Specific methods only
+ python prune.py threshold-hamming74decoder --methods mag,zero,evo
+
+ # Batch process
+ python prune.py --all --max-inputs 8
+
+ # Save best result
+ python prune.py threshold-hamming74decoder --save
+ ```
+
+ ## Requirements
+
+ ```
+ torch
+ safetensors
+ ```
+
+ ## Circuit Format
+
+ Each circuit needs:
+ ```
+ threshold-{name}/
+ ├── model.safetensors   # Weights: {layer.weight: [...], layer.bias: [...]}
+ ├── model.py            # Forward function
+ └── config.json         # {inputs, outputs, neurons, layers, parameters}
+ ```
+
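A minimal `config.json` matching the fields listed above (`inputs`, `outputs`, `neurons`, `layers`, `parameters`, plus `name` and an optional `description`) might look like this — the numbers below are illustrative, not taken from a real circuit:

```json
{
  "name": "threshold-hamming74decoder",
  "inputs": 7,
  "outputs": 4,
  "neurons": 11,
  "layers": 2,
  "parameters": 95,
  "description": "Hamming(7,4) decoder"
}
```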
+ ## Related
+
+ - [Threshold Logic Circuits Collection](https://huggingface.co/collections/phanerozoic/threshold-logic-circuits-6972546b096a4384dd9f34ad)
+
+ ## License
+
+ MIT
prune.py ADDED
@@ -0,0 +1,1162 @@
+ """
+ Unified Threshold Circuit Pruning Framework
+ ============================================
+
+ All pruning methods for threshold logic circuits in a single file.
+
+ Methods:
+ 1. Magnitude Reduction (sequential & batched GPU)
+ 2. Zero Pruning (sparsification)
+ 3. Weight Quantization (force to {-1,0,1})
+ 4. Evolutionary Search (mutation + selection)
+ 5. Simulated Annealing (gradual cooling)
+ 6. Pareto Frontier (correctness vs size tradeoff)
+
+ Usage:
+     python prune.py threshold-hamming74decoder
+     python prune.py threshold-hamming74decoder --methods magnitude,zero,evo
+     python prune.py --list
+     python prune.py --all --max-inputs 8
+
+ Author: Pruning framework for phanerozoic/threshold-logic-circuits
+ """
+
+ import torch
+ import torch.jit
+ import json
+ import math
+ import time
+ import random
+ import argparse
+ import importlib.util
+ import sys
+ from pathlib import Path
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Tuple, Optional, Callable, Set
+ from enum import Enum, auto
+ from datetime import datetime
+ from safetensors.torch import load_file, save_file
+
+
+ # =============================================================================
+ # CONFIGURATION
+ # =============================================================================
+
+ CIRCUITS_PATH = Path('D:/threshold-circuits')
+ RESULTS_PATH = CIRCUITS_PATH / 'pruned_results'
+
+
+ @dataclass
+ class Config:
+     """Global configuration for pruning."""
+     device: str = 'cuda'
+     fitness_threshold: float = 0.9999
+     batch_size: int = 80000
+     verbose: bool = True
+
+     # Method toggles
+     run_magnitude: bool = True
+     run_batched_magnitude: bool = True
+     run_zero: bool = True
+     run_quantize: bool = True
+     run_evolutionary: bool = True
+     run_annealing: bool = True
+     run_pareto: bool = True
+
+     # Method-specific
+     magnitude_passes: int = 100
+     evo_generations: int = 1000
+     evo_pop_size: int = 200
+     evo_mutation_rate: float = 0.1
+     evo_parsimony: float = 0.001
+     annealing_iterations: int = 10000
+     annealing_initial_temp: float = 10.0
+     annealing_cooling: float = 0.995
+     quantize_targets: List[float] = field(default_factory=lambda: [-1.0, 0.0, 1.0])
+     pareto_levels: List[float] = field(default_factory=lambda: [1.0, 0.99, 0.95, 0.90])
+
+
+ # =============================================================================
+ # CIRCUIT LOADING
+ # =============================================================================
+
+ @dataclass
+ class CircuitSpec:
+     """Metadata for a threshold circuit."""
+     name: str
+     path: Path
+     inputs: int
+     outputs: int
+     neurons: int
+     layers: int
+     parameters: int
+     description: str = ""
+
+
+ class Circuit:
+     """Threshold logic circuit loaded from safetensors."""
+
+     def __init__(self, path: Path, device: str = 'cuda'):
+         self.path = Path(path)
+         self.device = device
+         self.spec = self._load_spec()
+         self.weights = self._load_weights()
+
+     def _load_spec(self) -> CircuitSpec:
+         with open(self.path / 'config.json') as f:
+             cfg = json.load(f)
+         return CircuitSpec(
+             name=cfg['name'],
+             path=self.path,
+             inputs=cfg['inputs'],
+             outputs=cfg['outputs'],
+             neurons=cfg['neurons'],
+             layers=cfg['layers'],
+             parameters=cfg['parameters'],
+             description=cfg.get('description', '')
+         )
+
+     def _load_weights(self) -> Dict[str, torch.Tensor]:
+         w = load_file(str(self.path / 'model.safetensors'))
+         return {k: v.float().to(self.device) for k, v in w.items()}
+
+     def clone(self) -> Dict[str, torch.Tensor]:
+         return {k: v.clone() for k, v in self.weights.items()}
+
+     def stats(self, weights: Optional[Dict[str, torch.Tensor]] = None) -> Dict:
+         w = weights or self.weights
+         total = sum(t.numel() for t in w.values())
+         nonzero = sum((t != 0).sum().item() for t in w.values())
+         mag = sum(t.abs().sum().item() for t in w.values())
+         maxw = max(t.abs().max().item() for t in w.values())
+         unique = set()
+         for t in w.values():
+             unique.update(t.flatten().tolist())
+         return {
+             'total': total,
+             'nonzero': nonzero,
+             'zeros': total - nonzero,
+             'sparsity': 1 - nonzero/total if total else 0,
+             'magnitude': mag,
+             'max_weight': maxw,
+             'unique_count': len(unique),
+             'unique_values': sorted(unique)
+         }
+
+     def save(self, weights: Dict[str, torch.Tensor], suffix: str = 'pruned'):
+         path = self.path / f'model_{suffix}.safetensors'
+         cpu_w = {k: v.cpu() for k, v in weights.items()}
+         save_file(cpu_w, str(path))
+         return path
+
+
+ def discover_circuits(base: Path = CIRCUITS_PATH) -> List[CircuitSpec]:
+     """Find all circuits in the collection."""
+     circuits = []
+     for d in base.iterdir():
+         if d.is_dir() and (d / 'config.json').exists() and (d / 'model.safetensors').exists():
+             try:
+                 c = Circuit(d, device='cpu')
+                 circuits.append(c.spec)
+             except Exception as e:
+                 print(f"Skip {d.name}: {e}")
+     return sorted(circuits, key=lambda x: (x.inputs, x.neurons))
+
+
+ def load_circuit(name: str, device: str = 'cuda') -> Circuit:
+     """Load circuit by name."""
+     path = CIRCUITS_PATH / name
+     if not path.exists():
+         path = CIRCUITS_PATH / f'threshold-{name}'
+     if not path.exists():
+         raise ValueError(f"Circuit not found: {name}")
+     return Circuit(path, device)
+
+
+ # =============================================================================
+ # GPU UTILITIES
+ # =============================================================================
+
+ def gpu_memory() -> Dict:
+     if torch.cuda.is_available():
+         return {
+             'allocated': torch.cuda.memory_allocated() / 1e9,
+             'reserved': torch.cuda.memory_reserved() / 1e9,
+             'total': torch.cuda.get_device_properties(0).total_memory / 1e9
+         }
+     return {'allocated': 0, 'reserved': 0, 'total': 0}
+
+
+ def create_population(weights: Dict[str, torch.Tensor],
+                       pop_size: int, device: str) -> Dict[str, torch.Tensor]:
+     """Replicate weights for batched evaluation."""
+     return {
+         k: v.unsqueeze(0).expand(pop_size, *v.shape).clone().to(device)
+         for k, v in weights.items()
+     }
+
+
+ # =============================================================================
+ # GENERIC EVALUATOR
+ # =============================================================================
+
+ class Evaluator:
+     """
+     Generic evaluator for any threshold circuit.
+     Builds truth table and tests exhaustively.
+     """
+
+     def __init__(self, circuit: Circuit, forward_fn: Callable):
+         self.circuit = circuit
+         self.forward_fn = forward_fn
+         self.device = circuit.device
+         self.n_inputs = circuit.spec.inputs
+         self.n_cases = 2 ** self.n_inputs
+
+         self._build_inputs()
+         self._build_expected()
+
+     def _build_inputs(self):
+         """Generate all 2^n input combinations."""
+         if self.n_inputs > 20:
+             raise ValueError(f"Input space too large: 2^{self.n_inputs}")
+
+         idx = torch.arange(self.n_cases, device=self.device, dtype=torch.long)
+         bits = torch.arange(self.n_inputs, device=self.device, dtype=torch.long)
+         self.inputs = ((idx.unsqueeze(1) >> bits) & 1).float()
+
+     def _build_expected(self):
+         """Compute expected outputs using original weights."""
+         self.expected = self.forward_fn(self.inputs, self.circuit.weights)
+
+     def evaluate(self, weights: Dict[str, torch.Tensor]) -> float:
+         """Single evaluation: returns fitness 0.0-1.0."""
+         outputs = self.forward_fn(self.inputs, weights)
+         correct = (outputs == self.expected).all(dim=-1).float().sum()
+         return (correct / self.n_cases).item()
+
+     def evaluate_batch(self, population: Dict[str, torch.Tensor]) -> torch.Tensor:
+         """Batch evaluation: returns a [pop_size] fitness tensor."""
+         pop_size = next(iter(population.values())).shape[0]
+         fitness = torch.zeros(pop_size, device=self.device)
+
+         for i in range(pop_size):
+             w = {k: v[i] for k, v in population.items()}
+             outputs = self.forward_fn(self.inputs, w)
+             correct = (outputs == self.expected).all(dim=-1).float().sum()
+             fitness[i] = correct / self.n_cases
+
+         return fitness
+
+
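`_build_inputs` enumerates all 2^n cases by broadcasting a right-shift over the bit positions. The same indexing in plain Python (a torch-free sketch, not part of the framework):

```python
def all_input_rows(n_inputs):
    # Row i is the binary expansion of i, least-significant bit first,
    # mirroring ((idx.unsqueeze(1) >> bits) & 1) in Evaluator._build_inputs.
    return [[(i >> b) & 1 for b in range(n_inputs)] for i in range(2 ** n_inputs)]
```

Each row pairs with the matching row of the expected-output table, so fitness is simply the fraction of rows where all outputs agree.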
+ # =============================================================================
+ # PRUNING METHODS
+ # =============================================================================
+
+ @dataclass
+ class PruneResult:
+     """Result from a pruning method."""
+     method: str
+     original_stats: Dict
+     final_stats: Dict
+     final_weights: Dict[str, torch.Tensor]
+     fitness: float
+     reductions: int
+     time_seconds: float
+     history: List[Dict] = field(default_factory=list)
+
+
+ def get_candidates(weights: Dict[str, torch.Tensor]) -> List[Tuple[str, int, tuple, float]]:
+     """Get all non-zero weight positions."""
+     candidates = []
+     for name, tensor in weights.items():
+         flat = tensor.flatten()
+         for i in range(len(flat)):
+             val = flat[i].item()
+             if val != 0:
+                 candidates.append((name, i, tensor.shape, val))
+     return candidates
+
+
+ def apply_reduction(weights: Dict[str, torch.Tensor],
+                     name: str, idx: int, shape: tuple, old_val: float):
+     """Apply magnitude reduction: move weight 1 step toward zero."""
+     new_val = old_val - 1 if old_val > 0 else old_val + 1
+     flat = weights[name].flatten()
+     flat[idx] = new_val
+     weights[name] = flat.view(shape)
+
+
+ def revert_reduction(weights: Dict[str, torch.Tensor],
+                      name: str, idx: int, shape: tuple, old_val: float):
+     """Revert a reduction."""
+     flat = weights[name].flatten()
+     flat[idx] = old_val
+     weights[name] = flat.view(shape)
+
+
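The one-step "toward zero" move in `apply_reduction` is the primitive all magnitude methods share; isolated as a plain function (a sketch):

```python
def reduce_toward_zero(val):
    # Shrink an integer weight by one unit of magnitude, as apply_reduction does;
    # a weight of +1 or -1 becomes 0, which get_candidates then skips.
    return val - 1 if val > 0 else val + 1
```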
+ # -----------------------------------------------------------------------------
+ # Method 1: Sequential Magnitude Reduction
+ # -----------------------------------------------------------------------------
+
+ def prune_magnitude(weights: Dict[str, torch.Tensor],
+                     eval_fn: Callable[[Dict], float],
+                     cfg: Config) -> PruneResult:
+     """Reduce weight magnitudes one at a time."""
+     start = time.perf_counter()
+     weights = {k: v.clone() for k, v in weights.items()}
+     original = _stats(weights)
+     reductions = 0
+     history = []
+
+     if cfg.verbose:
+         print("  Starting magnitude reduction...")
+         print(f"  Original: mag={original['magnitude']:.0f}, nonzero={original['nonzero']}")
+
+     for pass_num in range(cfg.magnitude_passes):
+         candidates = get_candidates(weights)
+         if not candidates:
+             if cfg.verbose:
+                 print(f"  No candidates remaining at pass {pass_num}")
+             break
+
+         if cfg.verbose:
+             print(f"  Pass {pass_num}: testing {len(candidates)} candidates...")
+
+         pass_reductions = 0
+         tested = 0
+         for name, idx, shape, old_val in candidates:
+             apply_reduction(weights, name, idx, shape, old_val)
+             tested += 1
+
+             fitness = eval_fn(weights)
+             if fitness >= cfg.fitness_threshold:
+                 pass_reductions += 1
+                 reductions += 1
+                 if cfg.verbose:
+                     new_val = old_val - 1 if old_val > 0 else old_val + 1
+                     print(f"    ✓ {name}[{idx}]: {old_val} -> {new_val}")
+             else:
+                 revert_reduction(weights, name, idx, shape, old_val)
+
+             if cfg.verbose and tested % 50 == 0:
+                 s = _stats(weights)
+                 print(f"    Progress: {tested}/{len(candidates)}, reductions={pass_reductions}, mag={s['magnitude']:.0f}")
+
+         history.append({'pass': pass_num, 'reductions': pass_reductions})
+
+         s = _stats(weights)
+         if cfg.verbose:
+             print(f"  Pass {pass_num} complete: +{pass_reductions} reductions, mag={s['magnitude']:.0f}, nonzero={s['nonzero']}")
+
+         if pass_reductions == 0:
+             if cfg.verbose:
+                 print(f"  No progress at pass {pass_num}, stopping.")
+             break
+
+     return PruneResult(
+         method='magnitude',
+         original_stats=original,
+         final_stats=_stats(weights),
+         final_weights=weights,
+         fitness=eval_fn(weights),
+         reductions=reductions,
+         time_seconds=time.perf_counter() - start,
+         history=history
+     )
+
+
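The accept-or-revert loop is easiest to see on a toy neuron. Below, a single two-input threshold gate (a hypothetical example, not a circuit from the collection) is pruned greedily while its truth table is held fixed:

```python
def gate(w, b, x):
    # Threshold neuron: fires when the weighted sum reaches threshold b.
    return 1 if w[0] * x[0] + w[1] * x[1] >= b else 0

def truth_table(w, b):
    return [gate(w, b, (x0, x1)) for x0 in (0, 1) for x1 in (0, 1)]

def greedy_magnitude(w, b, target):
    # Accept a one-step reduction only if the full truth table survives,
    # mirroring the apply/revert pattern in prune_magnitude.
    changed = True
    while changed:
        changed = False
        for i, v in enumerate(w):
            if v == 0:
                continue
            w[i] = v - 1 if v > 0 else v + 1
            if truth_table(w, b) == target:
                changed = True
            else:
                w[i] = v  # revert
    return w

target = truth_table([3, 3], 2)          # an OR gate: [0, 1, 1, 1]
pruned = greedy_magnitude([3, 3], 2, target)
```

The loop halts at [2, 2]: any further reduction breaks a row of the truth table, so every remaining candidate is reverted and the pass makes no progress.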
+ # -----------------------------------------------------------------------------
+ # Method 2: Batched GPU Magnitude Reduction
+ # -----------------------------------------------------------------------------
+
+ def prune_magnitude_batched(weights: Dict[str, torch.Tensor],
+                             eval_fn: Callable[[Dict], float],
+                             batch_eval_fn: Callable[[Dict], torch.Tensor],
+                             cfg: Config) -> PruneResult:
+     """GPU-batched magnitude reduction."""
+     start = time.perf_counter()
+     weights = {k: v.clone() for k, v in weights.items()}
+     original = _stats(weights)
+     device = cfg.device
+     reductions = 0
+     history = []
+
+     for pass_num in range(cfg.magnitude_passes):
+         candidates = get_candidates(weights)
+         if not candidates:
+             break
+
+         # Phase 1: batch-test all candidates
+         successful = []
+         n = len(candidates)
+
+         for batch_start in range(0, n, cfg.batch_size):
+             batch = candidates[batch_start:batch_start + cfg.batch_size]
+             batch_len = len(batch)
+
+             pop = {name: tensor.unsqueeze(0).expand(batch_len, *tensor.shape).clone().to(device)
+                    for name, tensor in weights.items()}
+
+             for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
+                 new_val = old_val - 1 if old_val > 0 else old_val + 1
+                 flat_view = pop[name][pop_idx].flatten()
+                 flat_view[flat_idx] = new_val
+
+             fitness = batch_eval_fn(pop)
+
+             for pop_idx, cand in enumerate(batch):
+                 if fitness[pop_idx].item() >= cfg.fitness_threshold:
+                     successful.append(cand)
+
+         # Phase 2: apply with conflict resolution
+         pass_reductions = 0
+         for name, idx, shape, old_val in successful:
+             current_val = weights[name].flatten()[idx].item()
+             if current_val == old_val:
+                 apply_reduction(weights, name, idx, shape, old_val)
+                 if eval_fn(weights) >= cfg.fitness_threshold:
+                     pass_reductions += 1
+                     reductions += 1
+                 else:
+                     revert_reduction(weights, name, idx, shape, old_val)
+
+         history.append({'pass': pass_num, 'reductions': pass_reductions, 'candidates': len(successful)})
+
+         if cfg.verbose:
+             s = _stats(weights)
+             print(f"  Pass {pass_num}: {pass_reductions}/{len(successful)} applied, mag={s['magnitude']:.0f}")
+
+         if pass_reductions == 0:
+             break
+
+     return PruneResult(
+         method='batched_magnitude',
+         original_stats=original,
+         final_stats=_stats(weights),
+         final_weights=weights,
+         fitness=eval_fn(weights),
+         reductions=reductions,
+         time_seconds=time.perf_counter() - start,
+         history=history
+     )
+
+
+ # -----------------------------------------------------------------------------
+ # Method 3: Zero Pruning
+ # -----------------------------------------------------------------------------
+
+ def prune_zero(weights: Dict[str, torch.Tensor],
+                eval_fn: Callable[[Dict], float],
+                cfg: Config) -> PruneResult:
+     """Try setting weights directly to zero."""
+     start = time.perf_counter()
+     weights = {k: v.clone() for k, v in weights.items()}
+     original = _stats(weights)
+
+     candidates = get_candidates(weights)
+     random.shuffle(candidates)
+
+     if cfg.verbose:
+         print("  Starting zero pruning...")
+         print(f"  Original: mag={original['magnitude']:.0f}, nonzero={original['nonzero']}")
+         print(f"  Testing {len(candidates)} candidates (random order)...")
+
+     reductions = 0
+     tested = 0
+     for name, idx, shape, old_val in candidates:
+         flat = weights[name].flatten()
+         flat[idx] = 0
+         weights[name] = flat.view(shape)
+         tested += 1
+
+         if eval_fn(weights) >= cfg.fitness_threshold:
+             reductions += 1
+             if cfg.verbose:
+                 print(f"    ✓ {name}[{idx}]: {old_val} -> 0 (zeroed)")
+         else:
+             flat = weights[name].flatten()
+             flat[idx] = old_val
+             weights[name] = flat.view(shape)
+
+         if cfg.verbose and tested % 50 == 0:
+             s = _stats(weights)
+             print(f"    Progress: {tested}/{len(candidates)}, zeroed={reductions}, mag={s['magnitude']:.0f}")
+
+     if cfg.verbose:
+         s = _stats(weights)
+         print(f"  Zero pruning complete: {reductions} weights zeroed")
+         print(f"  Final: mag={s['magnitude']:.0f}, nonzero={s['nonzero']}")
+
+     return PruneResult(
+         method='zero',
+         original_stats=original,
+         final_stats=_stats(weights),
+         final_weights=weights,
+         fitness=eval_fn(weights),
+         reductions=reductions,
+         time_seconds=time.perf_counter() - start
+     )
+
+
+ # -----------------------------------------------------------------------------
+ # Method 4: Quantization
+ # -----------------------------------------------------------------------------
+
+ def prune_quantize(weights: Dict[str, torch.Tensor],
+                    eval_fn: Callable[[Dict], float],
+                    cfg: Config) -> PruneResult:
+     """Force weights to a target set (default: {-1, 0, 1})."""
+     start = time.perf_counter()
+     weights = {k: v.clone() for k, v in weights.items()}
+     original = _stats(weights)
+     target = torch.tensor(cfg.quantize_targets, device=weights[next(iter(weights))].device)
+     target_set = set(cfg.quantize_targets)
+
+     if cfg.verbose:
+         print("  Starting quantization...")
+         print(f"  Target values: {sorted(cfg.quantize_targets)}")
+         print(f"  Original unique values: {original.get('unique_count', len(set(v.item() for t in weights.values() for v in t.flatten())))}")
+         print(f"  Original: mag={original['magnitude']:.0f}, nonzero={original['nonzero']}")
+
+     # Count how many weights need quantizing
+     needs_quant = sum(1 for t in weights.values() for v in t.flatten() if v.item() not in target_set)
+     if cfg.verbose:
+         print(f"  Weights needing quantization: {needs_quant}")
+
+     reductions = 0
+     tested = 0
+     for name, tensor in list(weights.items()):
+         flat = tensor.flatten()
+         for i in range(len(flat)):
+             old_val = flat[i].item()
+             if old_val not in target_set:
+                 distances = (target - old_val).abs()
+                 closest = target[distances.argmin()].item()
+
+                 flat[i] = closest
+                 weights[name] = flat.view(tensor.shape)
+                 tested += 1
+
+                 if eval_fn(weights) >= cfg.fitness_threshold:
+                     reductions += 1
+                     if cfg.verbose:
+                         print(f"    ✓ {name}[{i}]: {old_val} -> {closest}")
+                 else:
+                     flat[i] = old_val
+                     weights[name] = flat.view(tensor.shape)
+
+                 if cfg.verbose and tested % 20 == 0:
+                     print(f"    Progress: {tested}/{needs_quant}, quantized={reductions}")
+
+     if cfg.verbose:
+         s = _stats(weights)
+         unique_now = len(set(v.item() for t in weights.values() for v in t.flatten()))
+         print(f"  Quantization complete: {reductions}/{tested} quantized")
+         print(f"  Final unique values: {unique_now}")
+         print(f"  Final: mag={s['magnitude']:.0f}, nonzero={s['nonzero']}")
+
+     return PruneResult(
+         method='quantize',
+         original_stats=original,
+         final_stats=_stats(weights),
+         final_weights=weights,
+         fitness=eval_fn(weights),
+         reductions=reductions,
+         time_seconds=time.perf_counter() - start
+     )
+
+
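The nearest-target rule above is an argmin over absolute distances; the scalar equivalent (a sketch of the same selection, not framework code):

```python
def nearest_target(val, targets=(-1.0, 0.0, 1.0)):
    # Pick the member of the target set closest to val,
    # as (target - old_val).abs().argmin() does on the tensor side.
    return min(targets, key=lambda t: abs(t - val))
```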
+ # -----------------------------------------------------------------------------
+ # Method 5: Evolutionary Search
+ # -----------------------------------------------------------------------------
+
+ def prune_evolutionary(weights: Dict[str, torch.Tensor],
+                        batch_eval_fn: Callable[[Dict], torch.Tensor],
+                        cfg: Config) -> PruneResult:
+     """Evolutionary search with parsimony pressure."""
+     start = time.perf_counter()
+     original = _stats(weights)
+     device = cfg.device
+     pop_size = cfg.evo_pop_size
+
+     if cfg.verbose:
+         print("  Starting evolutionary search...")
+         print(f"  Population: {pop_size}, Generations: {cfg.evo_generations}")
+         print(f"  Mutation rate: {cfg.evo_mutation_rate}, Parsimony: {cfg.evo_parsimony}")
+         print(f"  Original: mag={original['magnitude']:.0f}, nonzero={original['nonzero']}")
+
+     # Initialize population
+     pop = {k: v.unsqueeze(0).expand(pop_size, *v.shape).clone().to(device)
+            for k, v in weights.items()}
+
+     best_weights = {k: v.clone() for k, v in weights.items()}
+     best_score = -float('inf')
+     best_fitness = 1.0  # the original weights reproduce their own truth table
+     history = []
+     improved_at = 0
+
+     for gen in range(cfg.evo_generations):
+         # Evaluate
+         fitness = batch_eval_fn(pop)
+
+         # Compute magnitude penalty
+         mags = torch.stack([
+             sum(pop[name][i].abs().sum() for name in pop)
+             for i in range(pop_size)
+         ])
+         adjusted = fitness - cfg.evo_parsimony * mags
+
+         # Track best
+         best_idx = adjusted.argmax().item()
+         gen_best_fitness = fitness[best_idx].item()
+         gen_best_adj = adjusted[best_idx].item()
+         gen_best_mag = mags[best_idx].item()
+
+         if gen_best_fitness >= cfg.fitness_threshold:
+             if gen_best_adj > best_score:
+                 best_score = gen_best_adj
+                 best_fitness = gen_best_fitness
+                 best_weights = {k: v[best_idx].clone() for k, v in pop.items()}
+                 improved_at = gen
+                 if cfg.verbose:
+                     s = _stats(best_weights)
+                     print(f"  Gen {gen}: NEW BEST! score={best_score:.4f}, fitness={best_fitness:.4f}, mag={s['magnitude']:.0f}")
+
+         # Stats for this generation
+         valid_mask = fitness >= cfg.fitness_threshold
+         n_valid = valid_mask.sum().item()
+         avg_fitness = fitness.mean().item()
+         avg_mag = mags.mean().item()
+
+         if gen % 50 == 0:
+             s = _stats(best_weights)
+             if cfg.verbose:
+                 print(f"  Gen {gen}: valid={n_valid}/{pop_size}, avg_fit={avg_fitness:.4f}, avg_mag={avg_mag:.0f}, best_mag={s['magnitude']:.0f}")
+             history.append({'gen': gen, 'score': best_score, 'mag': s['magnitude'], 'n_valid': n_valid})
+
+         # Selection + mutation
+         probs = torch.softmax(adjusted, dim=0)
+         indices = torch.multinomial(probs, pop_size, replacement=True)
+
+         new_pop = {}
+         for name, tensor in pop.items():
+             selected = tensor[indices].clone()
+             mask = torch.rand_like(selected) < cfg.evo_mutation_rate
+             mutations = torch.randint(-1, 2, selected.shape, device=device).float()
+             selected = selected + mask.float() * mutations
+             new_pop[name] = selected
+         pop = new_pop
+
+     # Final report
+     final_stats = _stats(best_weights)
+     elapsed = time.perf_counter() - start
+
+     if cfg.verbose:
+         print(f"  Evolution complete in {elapsed:.1f}s")
+         print(f"  Best found at generation {improved_at}")
+         print(f"  Final: mag={final_stats['magnitude']:.0f}, nonzero={final_stats['nonzero']}")
+         reduction_pct = 100 * (1 - final_stats['magnitude'] / original['magnitude'])
+         print(f"  Magnitude reduction: {reduction_pct:.1f}%")
+
+     return PruneResult(
+         method='evolutionary',
+         original_stats=original,
+         final_stats=final_stats,
+         final_weights=best_weights,
+         fitness=best_fitness,
+         reductions=int(original['magnitude'] - final_stats['magnitude']),
+         time_seconds=elapsed,
+         history=history
+     )
+
+
+ # -----------------------------------------------------------------------------
+ # Method 6: Simulated Annealing
+ # -----------------------------------------------------------------------------
+
+ def prune_annealing(weights: Dict[str, torch.Tensor],
+                     eval_fn: Callable[[Dict], float],
+                     cfg: Config) -> PruneResult:
+     """Simulated annealing for circuit minimization."""
+     start = time.perf_counter()
+     weights = {k: v.clone() for k, v in weights.items()}
+     original = _stats(weights)
+
+     current = weights
+     current_energy = _energy(current, eval_fn, cfg)
+     best = {k: v.clone() for k, v in current.items()}
+     best_energy = current_energy
+
+     temp = cfg.annealing_initial_temp
+     history = []
+
+     for i in range(cfg.annealing_iterations):
+         # Perturb a single random weight
+         neighbor = {k: v.clone() for k, v in current.items()}
+         name = random.choice(list(neighbor.keys()))
+         flat = neighbor[name].flatten()
+         idx = random.randint(0, len(flat) - 1)
+         mutation = random.choice([-1, 1, 0])
+         if mutation == 0:
+             flat[idx] = 0
+         else:
+             flat[idx] = flat[idx] + mutation
+         neighbor[name] = flat.view(neighbor[name].shape)
+
+         neighbor_energy = _energy(neighbor, eval_fn, cfg)
+         delta = neighbor_energy - current_energy
+
+         if delta < 0 or random.random() < math.exp(-delta / max(temp, 1e-10)):
+             current = neighbor
+             current_energy = neighbor_energy
+
+             if current_energy < best_energy:
+                 if eval_fn(current) >= cfg.fitness_threshold:
+                     best = {k: v.clone() for k, v in current.items()}
+                     best_energy = current_energy
+
+         temp *= cfg.annealing_cooling
+
+         if i % 1000 == 0:
+             s = _stats(best)
+             if cfg.verbose:
+                 print(f"  Iter {i}: temp={temp:.4f}, mag={s['magnitude']:.0f}")
+             history.append({'iter': i, 'temp': temp, 'mag': s['magnitude']})
+
+     return PruneResult(
+         method='annealing',
+         original_stats=original,
+         final_stats=_stats(best),
+         final_weights=best,
+         fitness=eval_fn(best),
+         reductions=int(original['magnitude'] - _stats(best)['magnitude']),
+         time_seconds=time.perf_counter() - start,
+         history=history
+     )
+
+
+ def _energy(weights, eval_fn, cfg):
+     """Energy = total magnitude, with a large penalty below the fitness threshold."""
+     fitness = eval_fn(weights)
+     mag = sum(t.abs().sum().item() for t in weights.values())
+     if fitness < cfg.fitness_threshold:
+         return 1e6 + mag
+     return mag
+
+
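The acceptance test inside the loop is the standard Metropolis rule: energy drops always pass, energy rises pass with probability exp(-delta/T). Factored out for illustration (a sketch; the `rng` parameter is a hypothetical hook added here for determinism, not part of the framework):

```python
import math
import random

def accept(delta, temp, rng=random.random):
    # Downhill moves always pass; uphill moves pass with Boltzmann
    # probability, with the temperature floored as in prune_annealing.
    return delta < 0 or rng() < math.exp(-delta / max(temp, 1e-10))
```

Since `temp` decays by `annealing_cooling` every iteration, uphill moves become increasingly rare and the search freezes into a local minimum of `_energy`.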
+ # -----------------------------------------------------------------------------
+ # Method 7: Pareto Frontier
+ # -----------------------------------------------------------------------------
+
+ def prune_pareto(weights: Dict[str, torch.Tensor],
+                  eval_fn: Callable[[Dict], float],
+                  cfg: Config) -> PruneResult:
+     """Search the Pareto frontier of correctness vs size."""
+     start = time.perf_counter()
+     original = _stats(weights)
+     frontier = []
+
+     for target in cfg.pareto_levels:
+         if cfg.verbose:
+             print(f"  Target fitness >= {target}")
+
+         relaxed_cfg = Config(
+             device=cfg.device,
+             fitness_threshold=target,
+             magnitude_passes=50,
+             verbose=False
+         )
+
+         result = prune_magnitude({k: v.clone() for k, v in weights.items()}, eval_fn, relaxed_cfg)
+
+         frontier.append({
+             'target': target,
+             'actual': result.fitness,
+             'magnitude': result.final_stats['magnitude'],
+             'nonzero': result.final_stats['nonzero']
+         })
+
+         if cfg.verbose:
+             print(f"    -> fitness={result.fitness:.4f}, mag={result.final_stats['magnitude']:.0f}")
+
+     return PruneResult(
+         method='pareto',
+         original_stats=original,
+         final_stats=frontier[-1] if frontier else original,
+         final_weights=weights,
+         fitness=frontier[0]['actual'] if frontier else 1.0,
+         reductions=len(frontier),
+         time_seconds=time.perf_counter() - start,
+         history=frontier
+     )
+
+
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------

def _stats(weights: Dict[str, torch.Tensor]) -> Dict:
    """Summarize a weight dict: parameter count, non-zero count, L1 magnitude, max |w|."""
    total = sum(t.numel() for t in weights.values())
    nonzero = sum((t != 0).sum().item() for t in weights.values())
    mag = sum(t.abs().sum().item() for t in weights.values())
    maxw = max(t.abs().max().item() for t in weights.values()) if weights else 0
    return {'total': total, 'nonzero': nonzero, 'magnitude': mag, 'max': maxw}


import math

# =============================================================================
# CIRCUIT-SPECIFIC FORWARD FUNCTIONS
# =============================================================================

def make_hamming_decoder_forward(device='cuda'):
    """Create forward function for Hamming(7,4) decoder."""

    def forward(inputs, weights):
        """
        Forward pass for the Hamming decoder, evaluated case by case.
        inputs: [n_cases, 7]
        weights: dict of weight tensors
        Returns: [n_cases, 4]
        """
        n_cases = inputs.shape[0]
        w = weights
        outputs = []

        for case_idx in range(n_cases):
            c = [inputs[case_idx, i].item() for i in range(7)]

            def xor2(a, b, prefix):
                inp = torch.tensor([float(a), float(b)], device=device)
                or_out = float((inp * w[f'{prefix}.layer1.or.weight'].flatten()[:2]).sum() +
                               w[f'{prefix}.layer1.or.bias'].squeeze() >= 0)
                nand_out = float((inp * w[f'{prefix}.layer1.nand.weight'].flatten()[:2]).sum() +
                                 w[f'{prefix}.layer1.nand.bias'].squeeze() >= 0)
                l1 = torch.tensor([or_out, nand_out], device=device)
                return int((l1 * w[f'{prefix}.layer2.weight'].flatten()).sum() +
                           w[f'{prefix}.layer2.bias'].squeeze() >= 0)

            def xor4(indices, prefix):
                i0, i1, i2, i3 = indices
                inp = torch.tensor([float(c[i]) for i in range(7)], device=device)

                or_out = float((inp * w[f'{prefix}.xor_{i0}{i1}.layer1.or.weight'].flatten()).sum() +
                               w[f'{prefix}.xor_{i0}{i1}.layer1.or.bias'].squeeze() >= 0)
                nand_out = float((inp * w[f'{prefix}.xor_{i0}{i1}.layer1.nand.weight'].flatten()).sum() +
                                 w[f'{prefix}.xor_{i0}{i1}.layer1.nand.bias'].squeeze() >= 0)
                xor_ab = int((torch.tensor([or_out, nand_out], device=device) *
                              w[f'{prefix}.xor_{i0}{i1}.layer2.weight'].flatten()).sum() +
                             w[f'{prefix}.xor_{i0}{i1}.layer2.bias'].squeeze() >= 0)

                or_out = float((inp * w[f'{prefix}.xor_{i2}{i3}.layer1.or.weight'].flatten()).sum() +
                               w[f'{prefix}.xor_{i2}{i3}.layer1.or.bias'].squeeze() >= 0)
                nand_out = float((inp * w[f'{prefix}.xor_{i2}{i3}.layer1.nand.weight'].flatten()).sum() +
                                 w[f'{prefix}.xor_{i2}{i3}.layer1.nand.bias'].squeeze() >= 0)
                xor_cd = int((torch.tensor([or_out, nand_out], device=device) *
                              w[f'{prefix}.xor_{i2}{i3}.layer2.weight'].flatten()).sum() +
                             w[f'{prefix}.xor_{i2}{i3}.layer2.bias'].squeeze() >= 0)

                inp2 = torch.tensor([float(xor_ab), float(xor_cd)], device=device)
                or_out = float((inp2 * w[f'{prefix}.xor_final.layer1.or.weight'].flatten()).sum() +
                               w[f'{prefix}.xor_final.layer1.or.bias'].squeeze() >= 0)
                nand_out = float((inp2 * w[f'{prefix}.xor_final.layer1.nand.weight'].flatten()).sum() +
                                 w[f'{prefix}.xor_final.layer1.nand.bias'].squeeze() >= 0)
                return int((torch.tensor([or_out, nand_out], device=device) *
                            w[f'{prefix}.xor_final.layer2.weight'].flatten()).sum() +
                           w[f'{prefix}.xor_final.layer2.bias'].squeeze() >= 0)

            s1 = xor4([0, 2, 4, 6], 's1')
            s2 = xor4([1, 2, 5, 6], 's2')
            s3 = xor4([3, 4, 5, 6], 's3')

            syndrome = torch.tensor([float(s1), float(s2), float(s3)], device=device)

            flip3 = int((syndrome * w['flip3.weight'].flatten()).sum() + w['flip3.bias'].squeeze() >= 0)
            flip5 = int((syndrome * w['flip5.weight'].flatten()).sum() + w['flip5.bias'].squeeze() >= 0)
            flip6 = int((syndrome * w['flip6.weight'].flatten()).sum() + w['flip6.bias'].squeeze() >= 0)
            flip7 = int((syndrome * w['flip7.weight'].flatten()).sum() + w['flip7.bias'].squeeze() >= 0)

            d1 = xor2(c[2], flip3, 'd1.xor')
            d2 = xor2(c[4], flip5, 'd2.xor')
            d3 = xor2(c[5], flip6, 'd3.xor')
            d4 = xor2(c[6], flip7, 'd4.xor')

            outputs.append([d1, d2, d3, d4])

        return torch.tensor(outputs, device=device, dtype=torch.float32)

    # Build test cases with error injection: all 16 clean codewords plus
    # every single-bit corruption of each.
    def hamming_encode(data):
        d1, d2, d3, d4 = (data >> 0) & 1, (data >> 1) & 1, (data >> 2) & 1, (data >> 3) & 1
        p1, p2, p3 = d1 ^ d2 ^ d4, d1 ^ d3 ^ d4, d2 ^ d3 ^ d4
        return (p1 << 0) | (p2 << 1) | (d1 << 2) | (p3 << 3) | (d2 << 4) | (d3 << 5) | (d4 << 6)

    inputs_list, expected_list = [], []
    for data in range(16):
        cw = hamming_encode(data)
        inputs_list.append([(cw >> i) & 1 for i in range(7)])
        expected_list.append([(data >> i) & 1 for i in range(4)])
    for data in range(16):
        cw = hamming_encode(data)
        for flip in range(7):
            corrupted = cw ^ (1 << flip)
            inputs_list.append([(corrupted >> i) & 1 for i in range(7)])
            expected_list.append([(data >> i) & 1 for i in range(4)])

    test_inputs = torch.tensor(inputs_list, device=device, dtype=torch.float32)
    test_expected = torch.tensor(expected_list, device=device, dtype=torch.float32)

    return forward, test_inputs, test_expected

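The expected outputs above assume the classic Hamming(7,4) property: the three syndrome bits, read as a binary number, give the 1-based position of a single flipped bit. A pure-Python reference for the same bit layout (p1 p2 d1 p3 d2 d3 d4, matching `hamming_encode`; the `encode`/`decode` helpers here are illustrative, not part of prune.py) makes this easy to sanity-check independently of the threshold weights:

```python
def encode(data):
    # Same layout as hamming_encode in prune.py: bits 0..6 = p1 p2 d1 p3 d2 d3 d4.
    d1, d2, d3, d4 = (data >> 0) & 1, (data >> 1) & 1, (data >> 2) & 1, (data >> 3) & 1
    p1, p2, p3 = d1 ^ d2 ^ d4, d1 ^ d3 ^ d4, d2 ^ d3 ^ d4
    return (p1 << 0) | (p2 << 1) | (d1 << 2) | (p3 << 3) | (d2 << 4) | (d3 << 5) | (d4 << 6)

def decode(cw):
    bit = lambda i: (cw >> i) & 1
    # Same index groups as the xor4 syndrome computations above.
    s1 = bit(0) ^ bit(2) ^ bit(4) ^ bit(6)
    s2 = bit(1) ^ bit(2) ^ bit(5) ^ bit(6)
    s3 = bit(3) ^ bit(4) ^ bit(5) ^ bit(6)
    syndrome = s1 | (s2 << 1) | (s3 << 2)  # 1-based error position, 0 if clean
    if syndrome:
        cw ^= 1 << (syndrome - 1)
    bit = lambda i: (cw >> i) & 1
    return bit(2) | (bit(4) << 1) | (bit(5) << 2) | (bit(6) << 3)

# Every single-bit corruption of every codeword decodes back to the data word.
ok = all(decode(encode(d) ^ (1 << f)) == d for d in range(16) for f in range(7))
```

This mirrors exactly the 16 + 112 test cases the function builds: 16 clean codewords and 112 single-bit corruptions, all of which the circuit must decode correctly.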
def make_generic_forward(circuit: Circuit):
    """Create generic forward by loading model.py dynamically."""
    model_py = circuit.path / 'model.py'
    if not model_py.exists():
        return None, None, None

    spec = importlib.util.spec_from_file_location("circuit_model", model_py)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Find the main function
    fn_names = [circuit.spec.name.replace('threshold-', '').replace('-', '_'),
                'forward', 'evaluate', 'run']

    main_fn = None
    for name in dir(module):
        if name.lower() in [n.lower() for n in fn_names] and callable(getattr(module, name)):
            main_fn = getattr(module, name)
            break

    if main_fn is None:
        return None, None, None

    # Build inputs
    n = circuit.spec.inputs
    n_cases = 2 ** n
    device = circuit.device

    idx = torch.arange(n_cases, device=device, dtype=torch.long)
    bits = torch.arange(n, device=device, dtype=torch.long)
    inputs = ((idx.unsqueeze(1) >> bits) & 1).float()

    # Compute expected
    outputs = []
    for i in range(n_cases):
        args = [int(inputs[i, j].item()) for j in range(n)]
        result = main_fn(*args, circuit.weights)
        if isinstance(result, (list, tuple)):
            outputs.append([float(x) for x in result])
        else:
            outputs.append([float(result)])
    expected = torch.tensor(outputs, device=device, dtype=torch.float32)

    def forward(inp, weights):
        out = []
        for i in range(inp.shape[0]):
            args = [int(inp[i, j].item()) for j in range(n)]
            result = main_fn(*args, weights)
            if isinstance(result, (list, tuple)):
                out.append([float(x) for x in result])
            else:
                out.append([float(result)])
        return torch.tensor(out, device=device, dtype=torch.float32)

    return forward, inputs, expected

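The `((idx.unsqueeze(1) >> bits) & 1)` expression enumerates the full truth table: row i of `inputs` is the little-endian bit pattern of i, so 2**n rows cover every input combination. A pure-Python equivalent (illustrative helper, not part of prune.py) makes the row layout explicit:

```python
def truth_table_inputs(n):
    # Row i holds the little-endian bits of i: bit b of row i is (i >> b) & 1.
    return [[(i >> b) & 1 for b in range(n)] for i in range(2 ** n)]

table = truth_table_inputs(3)
# 8 rows; e.g. row 5 (binary 101, little-endian) is [1, 0, 1]
```

Exhaustive enumeration is what makes a fitness of 1.0 a proof of functional equivalence for these small circuits, which is also why the `--max-inputs` flag caps batch runs: the table doubles with every extra input.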
968
+
969
+ # =============================================================================
970
+ # MAIN ORCHESTRATOR
971
+ # =============================================================================
972
+
973
+ def run_all_methods(circuit: Circuit, cfg: Config) -> Dict[str, PruneResult]:
974
+ """Run all enabled pruning methods on a circuit."""
975
+
976
+ print(f"\n{'='*70}")
977
+ print(f" PRUNING: {circuit.spec.name}")
978
+ print(f"{'='*70}")
979
+
980
+ original = circuit.stats()
981
+ print(f" Inputs: {circuit.spec.inputs}, Outputs: {circuit.spec.outputs}")
982
+ print(f" Neurons: {circuit.spec.neurons}, Layers: {circuit.spec.layers}")
983
+ print(f" Parameters: {original['total']}, Non-zero: {original['nonzero']}")
984
+ print(f" Magnitude: {original['magnitude']:.0f}")
985
+ print(f"{'='*70}")
986
+
987
+ # Get forward function
988
+ if 'hamming74decoder' in circuit.spec.name:
989
+ forward_fn, test_inputs, test_expected = make_hamming_decoder_forward(cfg.device)
990
+ else:
991
+ forward_fn, test_inputs, test_expected = make_generic_forward(circuit)
992
+
993
+ if forward_fn is None:
994
+ print("ERROR: Could not create forward function")
995
+ return {}
996
+
997
+ # Create evaluators
998
+ def eval_fn(weights):
999
+ outputs = forward_fn(test_inputs, weights)
1000
+ correct = (outputs == test_expected).all(dim=-1).float().sum()
1001
+ return (correct / test_inputs.shape[0]).item()
1002
+
1003
+ def batch_eval_fn(population):
1004
+ pop_size = next(iter(population.values())).shape[0]
1005
+ fitness = torch.zeros(pop_size, device=cfg.device)
1006
+ for i in range(pop_size):
1007
+ w = {k: v[i] for k, v in population.items()}
1008
+ outputs = forward_fn(test_inputs, w)
1009
+ correct = (outputs == test_expected).all(dim=-1).float().sum()
1010
+ fitness[i] = correct / test_inputs.shape[0]
1011
+ return fitness
1012
+
1013
+ # Verify initial
1014
+ initial = eval_fn(circuit.weights)
1015
+ print(f"\n Initial fitness: {initial:.6f}")
1016
+ if initial < cfg.fitness_threshold:
1017
+ print(" ERROR: Circuit doesn't pass baseline!")
1018
+ return {}
1019
+
1020
+ results = {}
1021
+
1022
+ # Run methods
1023
+ if cfg.run_magnitude:
1024
+ print(f"\n[1] MAGNITUDE REDUCTION (sequential)")
1025
+ results['magnitude'] = prune_magnitude(circuit.clone(), eval_fn, cfg)
1026
+ _print_result(results['magnitude'])
1027
+
1028
+ if cfg.run_batched_magnitude:
1029
+ print(f"\n[2] MAGNITUDE REDUCTION (batched GPU)")
1030
+ results['batched'] = prune_magnitude_batched(circuit.clone(), eval_fn, batch_eval_fn, cfg)
1031
+ _print_result(results['batched'])
1032
+
1033
+ if cfg.run_zero:
1034
+ print(f"\n[3] ZERO PRUNING")
1035
+ results['zero'] = prune_zero(circuit.clone(), eval_fn, cfg)
1036
+ _print_result(results['zero'])
1037
+
1038
+ if cfg.run_quantize:
1039
+ print(f"\n[4] QUANTIZATION")
1040
+ results['quantize'] = prune_quantize(circuit.clone(), eval_fn, cfg)
1041
+ _print_result(results['quantize'])
1042
+
1043
+ if cfg.run_evolutionary:
1044
+ print(f"\n[5] EVOLUTIONARY")
1045
+ results['evolutionary'] = prune_evolutionary(circuit.clone(), batch_eval_fn, cfg)
1046
+ _print_result(results['evolutionary'])
1047
+
1048
+ if cfg.run_annealing:
1049
+ print(f"\n[6] SIMULATED ANNEALING")
1050
+ results['annealing'] = prune_annealing(circuit.clone(), eval_fn, cfg)
1051
+ _print_result(results['annealing'])
1052
+
1053
+ if cfg.run_pareto:
1054
+ print(f"\n[7] PARETO FRONTIER")
1055
+ results['pareto'] = prune_pareto(circuit.clone(), eval_fn, cfg)
1056
+ _print_result(results['pareto'])
1057
+
1058
+ # Summary
1059
+ print(f"\n{'='*70}")
1060
+ print(" SUMMARY")
1061
+ print(f"{'='*70}")
1062
+ print(f"\n{'Method':<20} {'Fitness':<10} {'Magnitude':<12} {'Nonzero':<10} {'Time':<10}")
1063
+ print("-" * 70)
1064
+ print(f"{'Original':<20} {'1.0000':<10} {original['magnitude']:<12.0f} {original['nonzero']:<10} {'-':<10}")
1065
+
1066
+ best_method, best_mag = None, float('inf')
1067
+ for name, r in sorted(results.items(), key=lambda x: x[1].final_stats.get('magnitude', float('inf'))):
1068
+ mag = r.final_stats.get('magnitude', 0)
1069
+ nz = r.final_stats.get('nonzero', 0)
1070
+ print(f"{name:<20} {r.fitness:<10.4f} {mag:<12.0f} {nz:<10} {r.time_seconds:<10.1f}s")
1071
+ if r.fitness >= cfg.fitness_threshold and mag < best_mag:
1072
+ best_mag = mag
1073
+ best_method = name
1074
+
1075
+ if best_method:
1076
+ reduction = 1 - best_mag / original['magnitude']
1077
+ print(f"\n BEST: {best_method} ({reduction*100:.1f}% magnitude reduction)")
1078
+
1079
+ return results
1080
+
1081
+
1082
+ def _print_result(r: PruneResult):
1083
+ print(f" Fitness: {r.fitness:.6f}")
1084
+ print(f" Magnitude: {r.final_stats.get('magnitude', 0):.0f}")
1085
+ print(f" Nonzero: {r.final_stats.get('nonzero', 0)}")
1086
+ print(f" Time: {r.time_seconds:.1f}s")
1087
+
1088
+
1089
+ # =============================================================================
1090
+ # CLI
1091
+ # =============================================================================
1092
+
1093
+ def main():
1094
+ parser = argparse.ArgumentParser(description='Prune threshold circuits')
1095
+ parser.add_argument('circuit', nargs='?', help='Circuit name')
1096
+ parser.add_argument('--list', action='store_true', help='List available circuits')
1097
+ parser.add_argument('--all', action='store_true', help='Run on all circuits')
1098
+ parser.add_argument('--max-inputs', type=int, default=10, help='Max inputs for --all')
1099
+ parser.add_argument('--device', default='cuda', help='cuda or cpu')
1100
+ parser.add_argument('--batch-size', type=int, default=80000)
1101
+ parser.add_argument('--methods', type=str, help='Comma-separated methods')
1102
+ parser.add_argument('--fitness', type=float, default=0.9999)
1103
+ parser.add_argument('--quiet', action='store_true')
1104
+ parser.add_argument('--save', action='store_true', help='Save best result')
1105
+
1106
+ args = parser.parse_args()
1107
+
1108
+ if args.list:
1109
+ specs = discover_circuits()
1110
+ print(f"\nAvailable circuits ({len(specs)}):\n")
1111
+ for s in specs:
1112
+ print(f" {s.name:<40} {s.inputs}in/{s.outputs}out {s.neurons}N {s.layers}L")
1113
+ return
1114
+
1115
+ cfg = Config(
1116
+ device=args.device,
1117
+ batch_size=args.batch_size,
1118
+ fitness_threshold=args.fitness,
1119
+ verbose=not args.quiet
1120
+ )
1121
+
1122
+ if args.methods:
1123
+ methods = args.methods.lower().split(',')
1124
+ cfg.run_magnitude = 'magnitude' in methods or 'mag' in methods
1125
+ cfg.run_batched_magnitude = 'batched' in methods or 'batch' in methods
1126
+ cfg.run_zero = 'zero' in methods
1127
+ cfg.run_quantize = 'quantize' in methods or 'quant' in methods
1128
+ cfg.run_evolutionary = 'evo' in methods or 'evolutionary' in methods
1129
+ cfg.run_annealing = 'anneal' in methods or 'sa' in methods
1130
+ cfg.run_pareto = 'pareto' in methods
1131
+
1132
+ RESULTS_PATH.mkdir(exist_ok=True)
1133
+
1134
+ if args.all:
1135
+ specs = [s for s in discover_circuits() if s.inputs <= args.max_inputs]
1136
+ print(f"\nRunning on {len(specs)} circuits...")
1137
+ for spec in specs:
1138
+ try:
1139
+ circuit = Circuit(spec.path, cfg.device)
1140
+ results = run_all_methods(circuit, cfg)
1141
+ except Exception as e:
1142
+ print(f"ERROR on {spec.name}: {e}")
1143
+ elif args.circuit:
1144
+ circuit = load_circuit(args.circuit, cfg.device)
1145
+ results = run_all_methods(circuit, cfg)
1146
+
1147
+ if args.save and results:
1148
+ best = min(results.values(), key=lambda r: r.final_stats.get('magnitude', float('inf')))
1149
+ if best.fitness >= cfg.fitness_threshold:
1150
+ path = circuit.save(best.final_weights, f'pruned_{best.method}')
1151
+ print(f"\nSaved to: {path}")
1152
+ else:
1153
+ parser.print_help()
1154
+ print("\n\nExamples:")
1155
+ print(" python prune.py --list")
1156
+ print(" python prune.py threshold-hamming74decoder")
1157
+ print(" python prune.py threshold-hamming74decoder --methods mag,zero,evo")
1158
+ print(" python prune.py --all --max-inputs 8")
1159
+
1160
+
1161
+ if __name__ == '__main__':
1162
+ main()