CharlesCNorton committed on
Commit
af38f62
·
1 Parent(s): e69d4eb

Rewrite pruning framework with full GPU vectorization

Browse files

- Fully vectorized forward pass (no Python loops over cases)
- True batched population evaluation
- VRAM management with overflow protection
- Added pruning methods: neuron, lottery ticket, topology search
- Improved evolutionary: elite preservation, crossover, adaptive mutation
- Circuit-specific optimized forward functions (Hamming encoder/decoder)

Files changed (1) hide show
  1. prune.py +1062 -747
prune.py CHANGED
@@ -1,87 +1,164 @@
1
  """
2
- Unified Threshold Circuit Pruning Framework
3
- ============================================
4
 
5
- All pruning methods for threshold logic circuits in a single file.
6
 
7
  Methods:
8
- 1. Magnitude Reduction (sequential & batched GPU)
9
- 2. Zero Pruning (sparsification)
10
- 3. Weight Quantization (force to {-1,0,1})
11
- 4. Evolutionary Search (mutation + selection)
12
- 5. Simulated Annealing (gradual cooling)
13
- 6. Pareto Frontier (correctness vs size tradeoff)
 
 
 
14
 
15
  Usage:
16
- python prune.py threshold-hamming74decoder
17
- python prune.py threshold-hamming74decoder --methods magnitude,zero,evo
18
- python prune.py --list
19
- python prune.py --all --max-inputs 8
20
-
21
- Author: Pruning framework for phanerozoic/threshold-logic-circuits
22
  """
23
 
24
  import torch
25
- import torch.jit
26
  import json
27
  import time
28
  import random
29
  import argparse
30
- import importlib.util
31
- import sys
32
  from pathlib import Path
33
  from dataclasses import dataclass, field
34
- from typing import Dict, List, Tuple, Optional, Callable, Set
35
- from enum import Enum, auto
36
- from datetime import datetime
37
  from safetensors.torch import load_file, save_file
 
 
38
 
39
-
40
- # =============================================================================
41
- # CONFIGURATION
42
- # =============================================================================
43
 
44
  CIRCUITS_PATH = Path('D:/threshold-circuits')
45
  RESULTS_PATH = CIRCUITS_PATH / 'pruned_results'
46
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  @dataclass
49
  class Config:
50
- """Global configuration for pruning."""
51
  device: str = 'cuda'
52
  fitness_threshold: float = 0.9999
53
- batch_size: int = 80000
54
  verbose: bool = True
 
55
 
56
- # Method toggles
57
  run_magnitude: bool = True
58
- run_batched_magnitude: bool = True
59
  run_zero: bool = True
60
  run_quantize: bool = True
61
  run_evolutionary: bool = True
62
  run_annealing: bool = True
 
 
 
63
  run_pareto: bool = True
64
 
65
- # Method-specific
66
  magnitude_passes: int = 100
67
- evo_generations: int = 1000
68
- evo_pop_size: int = 200
69
- evo_mutation_rate: float = 0.1
 
 
 
70
  evo_parsimony: float = 0.001
71
- annealing_iterations: int = 10000
 
 
72
  annealing_initial_temp: float = 10.0
73
- annealing_cooling: float = 0.995
 
74
  quantize_targets: List[float] = field(default_factory=lambda: [-1.0, 0.0, 1.0])
75
- pareto_levels: List[float] = field(default_factory=lambda: [1.0, 0.99, 0.95, 0.90])
 
 
 
76
 
 
 
 
77
 
78
- # =============================================================================
79
- # CIRCUIT LOADING
80
- # =============================================================================
81
 
82
  @dataclass
83
  class CircuitSpec:
84
- """Metadata for a threshold circuit."""
85
  name: str
86
  path: Path
87
  inputs: int
@@ -92,14 +169,36 @@ class CircuitSpec:
92
  description: str = ""
93
 
94
 
95
- class Circuit:
96
- """Threshold logic circuit loaded from safetensors."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def __init__(self, path: Path, device: str = 'cuda'):
99
  self.path = Path(path)
100
  self.device = device
101
  self.spec = self._load_spec()
102
  self.weights = self._load_weights()
 
 
 
 
 
103
 
104
  def _load_spec(self) -> CircuitSpec:
105
  with open(self.path / 'config.json') as f:
@@ -119,15 +218,315 @@ class Circuit:
119
  w = load_file(str(self.path / 'model.safetensors'))
120
  return {k: v.float().to(self.device) for k, v in w.items()}
121
 
122
- def clone(self) -> Dict[str, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  return {k: v.clone() for k, v in self.weights.items()}
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def stats(self, weights: Dict[str, torch.Tensor] = None) -> Dict:
126
  w = weights or self.weights
127
  total = sum(t.numel() for t in w.values())
128
  nonzero = sum((t != 0).sum().item() for t in w.values())
129
  mag = sum(t.abs().sum().item() for t in w.values())
130
- maxw = max(t.abs().max().item() for t in w.values())
131
  unique = set()
132
  for t in w.values():
133
  unique.update(t.flatten().tolist())
@@ -138,393 +537,238 @@ class Circuit:
138
  'sparsity': 1 - nonzero/total if total else 0,
139
  'magnitude': mag,
140
  'max_weight': maxw,
141
- 'unique_count': len(unique),
142
- 'unique_values': sorted(unique)
143
  }
144
 
145
- def save(self, weights: Dict[str, torch.Tensor], suffix: str = 'pruned'):
146
  path = self.path / f'model_{suffix}.safetensors'
147
  cpu_w = {k: v.cpu() for k, v in weights.items()}
148
  save_file(cpu_w, str(path))
149
  return path
150
 
151
 
152
- def discover_circuits(base: Path = CIRCUITS_PATH) -> List[CircuitSpec]:
153
- """Find all circuits in the collection."""
154
- circuits = []
155
- for d in base.iterdir():
156
- if d.is_dir() and (d / 'config.json').exists() and (d / 'model.safetensors').exists():
157
- try:
158
- c = Circuit(d, device='cpu')
159
- circuits.append(c.spec)
160
- except Exception as e:
161
- print(f"Skip {d.name}: {e}")
162
- return sorted(circuits, key=lambda x: (x.inputs, x.neurons))
163
-
164
-
165
- def load_circuit(name: str, device: str = 'cuda') -> Circuit:
166
- """Load circuit by name."""
167
- path = CIRCUITS_PATH / name
168
- if not path.exists():
169
- path = CIRCUITS_PATH / f'threshold-{name}'
170
- if not path.exists():
171
- raise ValueError(f"Circuit not found: {name}")
172
- return Circuit(path, device)
173
-
174
 
175
- # =============================================================================
176
- # GPU UTILITIES
177
- # =============================================================================
178
 
179
- def gpu_memory() -> Dict:
180
- if torch.cuda.is_available():
181
- return {
182
- 'allocated': torch.cuda.memory_allocated() / 1e9,
183
- 'reserved': torch.cuda.memory_reserved() / 1e9,
184
- 'total': torch.cuda.get_device_properties(0).total_memory / 1e9
185
- }
186
- return {'allocated': 0, 'reserved': 0, 'total': 0}
 
 
 
 
 
 
 
187
 
 
 
 
188
 
189
- def create_population(weights: Dict[str, torch.Tensor],
190
- pop_size: int, device: str) -> Dict[str, torch.Tensor]:
191
- """Replicate weights for batched evaluation."""
192
- return {
193
- k: v.unsqueeze(0).expand(pop_size, *v.shape).clone().to(device)
194
- for k, v in weights.items()
195
- }
196
 
 
 
 
197
 
198
- # =============================================================================
199
- # GENERIC EVALUATOR
200
- # =============================================================================
 
201
 
202
- class Evaluator:
203
- """
204
- Generic evaluator for any threshold circuit.
205
- Builds truth table and tests exhaustively.
206
- """
207
 
208
- def __init__(self, circuit: Circuit, forward_fn: Callable):
209
- self.circuit = circuit
210
- self.forward_fn = forward_fn
211
- self.device = circuit.device
212
- self.n_inputs = circuit.spec.inputs
213
- self.n_cases = 2 ** self.n_inputs
214
-
215
- self._build_inputs()
216
- self._build_expected()
217
-
218
- def _build_inputs(self):
219
- """Generate all 2^n input combinations."""
220
- if self.n_inputs > 20:
221
- raise ValueError(f"Input space too large: 2^{self.n_inputs}")
222
-
223
- idx = torch.arange(self.n_cases, device=self.device, dtype=torch.long)
224
- bits = torch.arange(self.n_inputs, device=self.device, dtype=torch.long)
225
- self.inputs = ((idx.unsqueeze(1) >> bits) & 1).float()
226
-
227
- def _build_expected(self):
228
- """Compute expected outputs using original weights."""
229
- self.expected = self.forward_fn(self.inputs, self.circuit.weights)
230
-
231
- def evaluate(self, weights: Dict[str, torch.Tensor]) -> float:
232
- """Single evaluation: returns fitness 0.0-1.0"""
233
- outputs = self.forward_fn(self.inputs, weights)
234
- correct = (outputs == self.expected).all(dim=-1).float().sum()
235
- return (correct / self.n_cases).item()
236
-
237
- def evaluate_batch(self, population: Dict[str, torch.Tensor]) -> torch.Tensor:
238
- """Batch evaluation: returns [pop_size] fitness tensor"""
239
- pop_size = next(iter(population.values())).shape[0]
240
  fitness = torch.zeros(pop_size, device=self.device)
241
 
242
- for i in range(pop_size):
243
- w = {k: v[i] for k, v in population.items()}
244
- outputs = self.forward_fn(self.inputs, w)
245
- correct = (outputs == self.expected).all(dim=-1).float().sum()
246
- fitness[i] = correct / self.n_cases
 
247
 
248
  return fitness
249
 
 
 
 
 
 
250
 
251
- # =============================================================================
252
- # PRUNING METHODS
253
- # =============================================================================
 
254
 
255
- @dataclass
256
- class PruneResult:
257
- """Result from a pruning method."""
258
- method: str
259
- original_stats: Dict
260
- final_stats: Dict
261
- final_weights: Dict[str, torch.Tensor]
262
- fitness: float
263
- reductions: int
264
- time_seconds: float
265
- history: List[Dict] = field(default_factory=list)
266
 
 
267
 
268
- def get_candidates(weights: Dict[str, torch.Tensor]) -> List[Tuple[str, int, tuple, float]]:
269
- """Get all non-zero weight positions."""
270
- candidates = []
271
- for name, tensor in weights.items():
272
- flat = tensor.flatten()
273
- for i in range(len(flat)):
274
- val = flat[i].item()
275
- if val != 0:
276
- candidates.append((name, i, tensor.shape, val))
277
- return candidates
278
 
 
 
 
279
 
280
- def apply_reduction(weights: Dict[str, torch.Tensor],
281
- name: str, idx: int, shape: tuple, old_val: float):
282
- """Apply magnitude reduction: move weight 1 step toward zero."""
283
- new_val = old_val - 1 if old_val > 0 else old_val + 1
284
- flat = weights[name].flatten()
285
- flat[idx] = new_val
286
- weights[name] = flat.view(shape)
287
 
 
288
 
289
- def revert_reduction(weights: Dict[str, torch.Tensor],
290
- name: str, idx: int, shape: tuple, old_val: float):
291
- """Revert a reduction."""
292
- flat = weights[name].flatten()
293
- flat[idx] = old_val
294
- weights[name] = flat.view(shape)
295
 
 
 
 
 
 
 
296
 
297
- # -----------------------------------------------------------------------------
298
- # Method 1: Sequential Magnitude Reduction
299
- # -----------------------------------------------------------------------------
300
 
301
- def prune_magnitude(weights: Dict[str, torch.Tensor],
302
- eval_fn: Callable[[Dict], float],
303
- cfg: Config) -> PruneResult:
304
- """Reduce weight magnitudes one at a time."""
305
  start = time.perf_counter()
306
- weights = {k: v.clone() for k, v in weights.items()}
307
- original = _stats(weights)
308
- reductions = 0
309
  history = []
 
310
 
311
  if cfg.verbose:
312
- print(f" Starting magnitude reduction...")
313
  print(f" Original: mag={original['magnitude']:.0f}, nonzero={original['nonzero']}")
314
 
315
  for pass_num in range(cfg.magnitude_passes):
316
- candidates = get_candidates(weights)
 
 
 
 
 
 
 
317
  if not candidates:
318
- if cfg.verbose:
319
- print(f" No candidates remaining at pass {pass_num}")
320
  break
321
 
322
- if cfg.verbose:
323
- print(f" Pass {pass_num}: testing {len(candidates)} candidates...")
324
-
325
  pass_reductions = 0
326
- tested = 0
327
  for name, idx, shape, old_val in candidates:
328
- apply_reduction(weights, name, idx, shape, old_val)
329
- tested += 1
 
 
 
 
 
330
 
331
- fitness = eval_fn(weights)
332
  if fitness >= cfg.fitness_threshold:
333
  pass_reductions += 1
334
- reductions += 1
335
- if cfg.verbose:
336
- new_val = old_val - 1 if old_val > 0 else old_val + 1
337
- print(f" ✓ {name}[{idx}]: {old_val} -> {new_val}")
338
  else:
339
- revert_reduction(weights, name, idx, shape, old_val)
 
 
340
 
341
- if cfg.verbose and tested % 50 == 0:
342
- s = _stats(weights)
343
- print(f" Progress: {tested}/{len(candidates)}, reductions={pass_reductions}, mag={s['magnitude']:.0f}")
344
 
345
- history.append({'pass': pass_num, 'reductions': pass_reductions})
346
-
347
- s = _stats(weights)
348
  if cfg.verbose:
349
- print(f" Pass {pass_num} complete: +{pass_reductions} reductions, mag={s['magnitude']:.0f}, nonzero={s['nonzero']}")
350
 
351
  if pass_reductions == 0:
352
- if cfg.verbose:
353
- print(f" No progress at pass {pass_num}, stopping.")
354
  break
355
 
356
  return PruneResult(
357
  method='magnitude',
358
  original_stats=original,
359
- final_stats=_stats(weights),
360
  final_weights=weights,
361
- fitness=eval_fn(weights),
362
- reductions=reductions,
363
  time_seconds=time.perf_counter() - start,
364
  history=history
365
  )
366
 
367
 
368
- # -----------------------------------------------------------------------------
369
- # Method 2: Batched GPU Magnitude Reduction
370
- # -----------------------------------------------------------------------------
371
-
372
- def prune_magnitude_batched(weights: Dict[str, torch.Tensor],
373
- eval_fn: Callable[[Dict], float],
374
- batch_eval_fn: Callable[[Dict], torch.Tensor],
375
- cfg: Config) -> PruneResult:
376
- """GPU-batched magnitude reduction."""
377
- start = time.perf_counter()
378
- weights = {k: v.clone() for k, v in weights.items()}
379
- original = _stats(weights)
380
- device = cfg.device
381
- reductions = 0
382
- history = []
383
-
384
- for pass_num in range(cfg.magnitude_passes):
385
- candidates = get_candidates(weights)
386
- if not candidates:
387
- break
388
-
389
- # Phase 1: Batch test all candidates
390
- successful = []
391
- n = len(candidates)
392
-
393
- for batch_start in range(0, n, cfg.batch_size):
394
- batch = candidates[batch_start:batch_start + cfg.batch_size]
395
- batch_len = len(batch)
396
-
397
- pop = {name: tensor.unsqueeze(0).expand(batch_len, *tensor.shape).clone().to(device)
398
- for name, tensor in weights.items()}
399
-
400
- for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
401
- new_val = old_val - 1 if old_val > 0 else old_val + 1
402
- flat_view = pop[name][pop_idx].flatten()
403
- flat_view[flat_idx] = new_val
404
-
405
- fitness = batch_eval_fn(pop)
406
-
407
- for pop_idx, cand in enumerate(batch):
408
- if fitness[pop_idx].item() >= cfg.fitness_threshold:
409
- successful.append(cand)
410
-
411
- # Phase 2: Apply with conflict resolution
412
- pass_reductions = 0
413
- for name, idx, shape, old_val in successful:
414
- current_val = weights[name].flatten()[idx].item()
415
- if current_val == old_val:
416
- apply_reduction(weights, name, idx, shape, old_val)
417
- if eval_fn(weights) >= cfg.fitness_threshold:
418
- pass_reductions += 1
419
- reductions += 1
420
- else:
421
- revert_reduction(weights, name, idx, shape, old_val)
422
-
423
- history.append({'pass': pass_num, 'reductions': pass_reductions, 'candidates': len(successful)})
424
-
425
- if cfg.verbose:
426
- s = _stats(weights)
427
- print(f" Pass {pass_num}: {pass_reductions}/{len(successful)} applied, mag={s['magnitude']:.0f}")
428
-
429
- if pass_reductions == 0:
430
- break
431
-
432
- return PruneResult(
433
- method='batched_magnitude',
434
- original_stats=original,
435
- final_stats=_stats(weights),
436
- final_weights=weights,
437
- fitness=eval_fn(weights),
438
- reductions=reductions,
439
- time_seconds=time.perf_counter() - start,
440
- history=history
441
- )
442
-
443
-
444
- # -----------------------------------------------------------------------------
445
- # Method 3: Zero Pruning
446
- # -----------------------------------------------------------------------------
447
-
448
- def prune_zero(weights: Dict[str, torch.Tensor],
449
- eval_fn: Callable[[Dict], float],
450
  cfg: Config) -> PruneResult:
451
- """Try setting weights directly to zero."""
452
  start = time.perf_counter()
453
- weights = {k: v.clone() for k, v in weights.items()}
454
- original = _stats(weights)
 
 
 
 
 
 
 
 
455
 
456
- candidates = get_candidates(weights)
457
  random.shuffle(candidates)
458
 
459
  if cfg.verbose:
460
- print(f" Starting zero pruning...")
461
- print(f" Original: mag={original['magnitude']:.0f}, nonzero={original['nonzero']}")
462
- print(f" Testing {len(candidates)} candidates (random order)...")
463
 
464
- reductions = 0
465
- tested = 0
466
  for name, idx, shape, old_val in candidates:
467
  flat = weights[name].flatten()
468
  flat[idx] = 0
469
  weights[name] = flat.view(shape)
470
- tested += 1
471
 
472
- if eval_fn(weights) >= cfg.fitness_threshold:
473
- reductions += 1
474
- if cfg.verbose:
475
- print(f" ✓ {name}[{idx}]: {old_val} -> 0 (zeroed)")
476
  else:
477
  flat = weights[name].flatten()
478
  flat[idx] = old_val
479
  weights[name] = flat.view(shape)
480
 
481
- if cfg.verbose and tested % 50 == 0:
482
- s = _stats(weights)
483
- print(f" Progress: {tested}/{len(candidates)}, zeroed={reductions}, mag={s['magnitude']:.0f}")
484
-
485
  if cfg.verbose:
486
- s = _stats(weights)
487
- print(f" Zero pruning complete: {reductions} weights zeroed")
488
- print(f" Final: mag={s['magnitude']:.0f}, nonzero={s['nonzero']}")
489
 
490
  return PruneResult(
491
  method='zero',
492
  original_stats=original,
493
- final_stats=_stats(weights),
494
  final_weights=weights,
495
- fitness=eval_fn(weights),
496
- reductions=reductions,
497
  time_seconds=time.perf_counter() - start
498
  )
499
 
500
 
501
- # -----------------------------------------------------------------------------
502
- # Method 4: Quantization
503
- # -----------------------------------------------------------------------------
504
-
505
- def prune_quantize(weights: Dict[str, torch.Tensor],
506
- eval_fn: Callable[[Dict], float],
507
  cfg: Config) -> PruneResult:
508
- """Force weights to target set (default: {-1,0,1})."""
509
  start = time.perf_counter()
510
- weights = {k: v.clone() for k, v in weights.items()}
511
- original = _stats(weights)
512
- target = torch.tensor(cfg.quantize_targets, device=weights[next(iter(weights))].device)
513
  target_set = set(cfg.quantize_targets)
514
 
515
  if cfg.verbose:
516
- print(f" Starting quantization...")
517
- print(f" Target values: {sorted(cfg.quantize_targets)}")
518
- print(f" Original unique values: {original.get('unique_count', len(set(v.item() for t in weights.values() for v in t.flatten())))}")
519
- print(f" Original: mag={original['magnitude']:.0f}, nonzero={original['nonzero']}")
520
-
521
- # Count how many need quantizing
522
- needs_quant = sum(1 for t in weights.values() for v in t.flatten() if v.item() not in target_set)
523
- if cfg.verbose:
524
- print(f" Weights needing quantization: {needs_quant}")
525
 
526
- reductions = 0
527
- tested = 0
528
  for name, tensor in list(weights.items()):
529
  flat = tensor.flatten()
530
  for i in range(len(flat)):
@@ -535,539 +779,574 @@ def prune_quantize(weights: Dict[str, torch.Tensor],
535
 
536
  flat[i] = closest
537
  weights[name] = flat.view(tensor.shape)
538
- tested += 1
539
 
540
- if eval_fn(weights) >= cfg.fitness_threshold:
541
- reductions += 1
542
- if cfg.verbose:
543
- print(f" ✓ {name}[{i}]: {old_val} -> {closest}")
544
  else:
545
  flat[i] = old_val
546
  weights[name] = flat.view(tensor.shape)
547
 
548
- if cfg.verbose and tested % 20 == 0:
549
- print(f" Progress: {tested}/{needs_quant}, quantized={reductions}")
550
-
551
  if cfg.verbose:
552
- s = _stats(weights)
553
- unique_now = len(set(v.item() for t in weights.values() for v in t.flatten()))
554
- print(f" Quantization complete: {reductions}/{tested} quantized")
555
- print(f" Final unique values: {unique_now}")
556
- print(f" Final: mag={s['magnitude']:.0f}, nonzero={s['nonzero']}")
557
 
558
  return PruneResult(
559
  method='quantize',
560
  original_stats=original,
561
- final_stats=_stats(weights),
562
  final_weights=weights,
563
- fitness=eval_fn(weights),
564
- reductions=reductions,
565
  time_seconds=time.perf_counter() - start
566
  )
567
 
568
 
569
- # -----------------------------------------------------------------------------
570
- # Method 5: Evolutionary Search
571
- # -----------------------------------------------------------------------------
572
-
573
- def prune_evolutionary(weights: Dict[str, torch.Tensor],
574
- batch_eval_fn: Callable[[Dict], torch.Tensor],
575
  cfg: Config) -> PruneResult:
576
- """Evolutionary search with parsimony pressure."""
 
 
 
 
 
 
 
577
  start = time.perf_counter()
578
- original = _stats(weights)
579
- device = cfg.device
580
- pop_size = cfg.evo_pop_size
 
581
 
582
  if cfg.verbose:
583
- print(f" Starting evolutionary search...")
584
- print(f" Population: {pop_size}, Generations: {cfg.evo_generations}")
585
- print(f" Mutation rate: {cfg.evo_mutation_rate}, Parsimony: {cfg.evo_parsimony}")
586
- print(f" Original: mag={original['magnitude']:.0f}, nonzero={original['nonzero']}")
587
 
588
- # Initialize population
589
- pop = {k: v.unsqueeze(0).expand(pop_size, *v.shape).clone().to(device)
590
- for k, v in weights.items()}
591
 
592
- best_weights = {k: v.clone() for k, v in weights.items()}
 
 
 
 
 
593
  best_score = -float('inf')
594
  best_fitness = 0.0
 
 
595
  history = []
596
- improved_at = 0
597
 
598
  for gen in range(cfg.evo_generations):
599
- # Evaluate
600
- fitness = batch_eval_fn(pop)
601
-
602
- # Compute magnitude penalty
603
- mags = torch.stack([
604
- sum(pop[name][i].abs().sum() for name in pop)
605
- for i in range(pop_size)
606
- ])
607
- adjusted = fitness - cfg.evo_parsimony * mags
608
-
609
- # Track best
610
- best_idx = adjusted.argmax().item()
611
- gen_best_fitness = fitness[best_idx].item()
612
- gen_best_adj = adjusted[best_idx].item()
613
- gen_best_mag = mags[best_idx].item()
614
-
615
- if gen_best_fitness >= cfg.fitness_threshold:
616
- if gen_best_adj > best_score:
617
- best_score = gen_best_adj
618
- best_fitness = gen_best_fitness
619
- best_weights = {k: v[best_idx].clone() for k, v in pop.items()}
620
- improved_at = gen
621
- if cfg.verbose:
622
- s = _stats(best_weights)
623
- print(f" Gen {gen}: NEW BEST! score={best_score:.4f}, fitness={best_fitness:.4f}, mag={s['magnitude']:.0f}")
624
-
625
- # Stats for this generation
626
  valid_mask = fitness >= cfg.fitness_threshold
627
  n_valid = valid_mask.sum().item()
628
- avg_fitness = fitness.mean().item()
629
- avg_mag = mags.mean().item()
630
 
631
- if gen % 50 == 0:
632
- s = _stats(best_weights)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
  if cfg.verbose:
634
- print(f" Gen {gen}: valid={n_valid}/{pop_size}, avg_fit={avg_fitness:.4f}, avg_mag={avg_mag:.0f}, best_mag={s['magnitude']:.0f}")
635
- history.append({'gen': gen, 'score': best_score, 'mag': s['magnitude'], 'n_valid': n_valid})
636
-
637
- # Selection + mutation
638
- probs = torch.softmax(adjusted, dim=0)
639
- indices = torch.multinomial(probs, pop_size, replacement=True)
640
-
641
- new_pop = {}
642
- for name, tensor in pop.items():
643
- selected = tensor[indices].clone()
644
- mask = torch.rand_like(selected) < cfg.evo_mutation_rate
645
- mutations = torch.randint(-1, 2, selected.shape, device=device).float()
646
- selected = selected + mask.float() * mutations
647
- new_pop[name] = selected
648
- pop = new_pop
649
-
650
- # Final report
651
- final_stats = _stats(best_weights)
652
- elapsed = time.perf_counter() - start
653
 
654
- if cfg.verbose:
655
- print(f" Evolution complete in {elapsed:.1f}s")
656
- print(f" Best found at generation {improved_at}")
657
- print(f" Final: mag={final_stats['magnitude']:.0f}, nonzero={final_stats['nonzero']}")
658
- reduction_pct = 100 * (1 - final_stats['magnitude'] / original['magnitude'])
659
- print(f" Magnitude reduction: {reduction_pct:.1f}%")
660
 
661
  return PruneResult(
662
  method='evolutionary',
663
  original_stats=original,
664
  final_stats=final_stats,
665
  final_weights=best_weights,
666
- fitness=best_score + cfg.evo_parsimony * final_stats['magnitude'],
667
- reductions=int(original['magnitude'] - final_stats['magnitude']),
668
- time_seconds=elapsed,
669
- history=history
670
  )
671
 
672
 
673
- # -----------------------------------------------------------------------------
674
- # Method 6: Simulated Annealing
675
- # -----------------------------------------------------------------------------
676
-
677
- def prune_annealing(weights: Dict[str, torch.Tensor],
678
- eval_fn: Callable[[Dict], float],
679
  cfg: Config) -> PruneResult:
680
- """Simulated annealing for circuit minimization."""
681
  start = time.perf_counter()
682
- weights = {k: v.clone() for k, v in weights.items()}
683
- original = _stats(weights)
 
 
 
 
 
 
 
 
684
 
685
- current = weights
686
- current_energy = _energy(current, eval_fn, cfg)
687
  best = {k: v.clone() for k, v in current.items()}
688
  best_energy = current_energy
 
689
 
690
  temp = cfg.annealing_initial_temp
691
  history = []
692
 
 
 
 
693
  for i in range(cfg.annealing_iterations):
694
- # Perturb
695
  neighbor = {k: v.clone() for k, v in current.items()}
696
  name = random.choice(list(neighbor.keys()))
697
  flat = neighbor[name].flatten()
698
  idx = random.randint(0, len(flat) - 1)
699
- mutation = random.choice([-1, 1, 0])
 
700
  if mutation == 0:
701
  flat[idx] = 0
702
  else:
703
  flat[idx] = flat[idx] + mutation
704
  neighbor[name] = flat.view(neighbor[name].shape)
705
 
706
- neighbor_energy = _energy(neighbor, eval_fn, cfg)
 
 
 
 
 
 
 
707
  delta = neighbor_energy - current_energy
708
 
709
  if delta < 0 or random.random() < math.exp(-delta / max(temp, 1e-10)):
710
  current = neighbor
711
  current_energy = neighbor_energy
 
712
 
713
- if current_energy < best_energy:
714
- if eval_fn(current) >= cfg.fitness_threshold:
715
- best = {k: v.clone() for k, v in current.items()}
716
- best_energy = current_energy
717
 
718
  temp *= cfg.annealing_cooling
719
 
720
- if i % 1000 == 0:
721
- s = _stats(best)
 
722
  if cfg.verbose:
723
- print(f" Iter {i}: temp={temp:.4f}, mag={s['magnitude']:.0f}")
724
- history.append({'iter': i, 'temp': temp, 'mag': s['magnitude']})
725
 
726
  return PruneResult(
727
  method='annealing',
728
  original_stats=original,
729
- final_stats=_stats(best),
730
  final_weights=best,
731
- fitness=eval_fn(best),
732
- reductions=int(original['magnitude'] - _stats(best)['magnitude']),
733
  time_seconds=time.perf_counter() - start,
734
  history=history
735
  )
736
 
737
 
738
- def _energy(weights, eval_fn, cfg):
739
- fitness = eval_fn(weights)
740
- mag = sum(t.abs().sum().item() for t in weights.values())
741
- if fitness < cfg.fitness_threshold:
742
- return 1e6 + mag
743
- return mag
744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745
 
746
- # -----------------------------------------------------------------------------
747
- # Method 7: Pareto Frontier
748
- # -----------------------------------------------------------------------------
749
 
750
- def prune_pareto(weights: Dict[str, torch.Tensor],
751
- eval_fn: Callable[[Dict], float],
752
- cfg: Config) -> PruneResult:
753
- """Search Pareto frontier of correctness vs size."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
  start = time.perf_counter()
755
- original = _stats(weights)
756
- frontier = []
757
 
758
- for target in cfg.pareto_levels:
759
- if cfg.verbose:
760
- print(f" Target fitness >= {target}")
761
 
762
- relaxed_cfg = Config(
763
- device=cfg.device,
764
- fitness_threshold=target,
765
- magnitude_passes=50,
766
- verbose=False
767
- )
768
 
769
- result = prune_magnitude({k: v.clone() for k, v in weights.items()}, eval_fn, relaxed_cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
770
 
771
- frontier.append({
772
- 'target': target,
773
- 'actual': result.fitness,
774
- 'magnitude': result.final_stats['magnitude'],
775
- 'nonzero': result.final_stats['nonzero']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
  })
777
 
778
  if cfg.verbose:
779
- print(f" -> fitness={result.fitness:.4f}, mag={result.final_stats['magnitude']:.0f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
780
 
781
  return PruneResult(
782
- method='pareto',
783
  original_stats=original,
784
- final_stats=frontier[-1] if frontier else original,
785
  final_weights=weights,
786
- fitness=frontier[0]['actual'] if frontier else 1.0,
787
- reductions=len(frontier),
788
  time_seconds=time.perf_counter() - start,
789
- history=frontier
790
  )
791
 
792
 
793
- # -----------------------------------------------------------------------------
794
- # Helpers
795
- # -----------------------------------------------------------------------------
796
-
797
- def _stats(weights: Dict[str, torch.Tensor]) -> Dict:
798
- total = sum(t.numel() for t in weights.values())
799
- nonzero = sum((t != 0).sum().item() for t in weights.values())
800
- mag = sum(t.abs().sum().item() for t in weights.values())
801
- maxw = max(t.abs().max().item() for t in weights.values()) if weights else 0
802
- return {'total': total, 'nonzero': nonzero, 'magnitude': mag, 'max': maxw}
803
-
804
 
805
- import math
 
 
 
 
806
 
 
807
 
808
- # =============================================================================
809
- # CIRCUIT-SPECIFIC FORWARD FUNCTIONS
810
- # =============================================================================
 
 
 
 
 
 
 
811
 
812
- def make_hamming_decoder_forward(device='cuda'):
813
- """Create forward function for Hamming(7,4) decoder."""
814
 
815
- def forward(inputs, weights):
816
- """
817
- Batched forward pass for Hamming decoder.
818
- inputs: [n_cases, 7]
819
- weights: dict of weight tensors
820
- Returns: [n_cases, 4]
821
- """
822
- n_cases = inputs.shape[0]
823
- w = weights
824
- outputs = []
825
-
826
- for case_idx in range(n_cases):
827
- c = [inputs[case_idx, i].item() for i in range(7)]
828
-
829
- def xor2(a, b, prefix):
830
- inp = torch.tensor([float(a), float(b)], device=device)
831
- or_out = float((inp * w[f'{prefix}.layer1.or.weight'].flatten()[:2]).sum() +
832
- w[f'{prefix}.layer1.or.bias'].squeeze() >= 0)
833
- nand_out = float((inp * w[f'{prefix}.layer1.nand.weight'].flatten()[:2]).sum() +
834
- w[f'{prefix}.layer1.nand.bias'].squeeze() >= 0)
835
- l1 = torch.tensor([or_out, nand_out], device=device)
836
- return int((l1 * w[f'{prefix}.layer2.weight'].flatten()).sum() +
837
- w[f'{prefix}.layer2.bias'].squeeze() >= 0)
838
-
839
- def xor4(indices, prefix):
840
- i0, i1, i2, i3 = indices
841
- inp = torch.tensor([float(c[i]) for i in range(7)], device=device)
842
-
843
- or_out = float((inp * w[f'{prefix}.xor_{i0}{i1}.layer1.or.weight'].flatten()).sum() +
844
- w[f'{prefix}.xor_{i0}{i1}.layer1.or.bias'].squeeze() >= 0)
845
- nand_out = float((inp * w[f'{prefix}.xor_{i0}{i1}.layer1.nand.weight'].flatten()).sum() +
846
- w[f'{prefix}.xor_{i0}{i1}.layer1.nand.bias'].squeeze() >= 0)
847
- xor_ab = int((torch.tensor([or_out, nand_out], device=device) *
848
- w[f'{prefix}.xor_{i0}{i1}.layer2.weight'].flatten()).sum() +
849
- w[f'{prefix}.xor_{i0}{i1}.layer2.bias'].squeeze() >= 0)
850
-
851
- or_out = float((inp * w[f'{prefix}.xor_{i2}{i3}.layer1.or.weight'].flatten()).sum() +
852
- w[f'{prefix}.xor_{i2}{i3}.layer1.or.bias'].squeeze() >= 0)
853
- nand_out = float((inp * w[f'{prefix}.xor_{i2}{i3}.layer1.nand.weight'].flatten()).sum() +
854
- w[f'{prefix}.xor_{i2}{i3}.layer1.nand.bias'].squeeze() >= 0)
855
- xor_cd = int((torch.tensor([or_out, nand_out], device=device) *
856
- w[f'{prefix}.xor_{i2}{i3}.layer2.weight'].flatten()).sum() +
857
- w[f'{prefix}.xor_{i2}{i3}.layer2.bias'].squeeze() >= 0)
858
-
859
- inp2 = torch.tensor([float(xor_ab), float(xor_cd)], device=device)
860
- or_out = float((inp2 * w[f'{prefix}.xor_final.layer1.or.weight'].flatten()).sum() +
861
- w[f'{prefix}.xor_final.layer1.or.bias'].squeeze() >= 0)
862
- nand_out = float((inp2 * w[f'{prefix}.xor_final.layer1.nand.weight'].flatten()).sum() +
863
- w[f'{prefix}.xor_final.layer1.nand.bias'].squeeze() >= 0)
864
- return int((torch.tensor([or_out, nand_out], device=device) *
865
- w[f'{prefix}.xor_final.layer2.weight'].flatten()).sum() +
866
- w[f'{prefix}.xor_final.layer2.bias'].squeeze() >= 0)
867
-
868
- s1 = xor4([0, 2, 4, 6], 's1')
869
- s2 = xor4([1, 2, 5, 6], 's2')
870
- s3 = xor4([3, 4, 5, 6], 's3')
871
-
872
- syndrome = torch.tensor([float(s1), float(s2), float(s3)], device=device)
873
-
874
- flip3 = int((syndrome * w['flip3.weight'].flatten()).sum() + w['flip3.bias'].squeeze() >= 0)
875
- flip5 = int((syndrome * w['flip5.weight'].flatten()).sum() + w['flip5.bias'].squeeze() >= 0)
876
- flip6 = int((syndrome * w['flip6.weight'].flatten()).sum() + w['flip6.bias'].squeeze() >= 0)
877
- flip7 = int((syndrome * w['flip7.weight'].flatten()).sum() + w['flip7.bias'].squeeze() >= 0)
878
-
879
- d1 = xor2(c[2], flip3, 'd1.xor')
880
- d2 = xor2(c[4], flip5, 'd2.xor')
881
- d3 = xor2(c[5], flip6, 'd3.xor')
882
- d4 = xor2(c[6], flip7, 'd4.xor')
883
-
884
- outputs.append([d1, d2, d3, d4])
885
-
886
- return torch.tensor(outputs, device=device, dtype=torch.float32)
887
-
888
- # Build test cases with error injection
889
- def hamming_encode(data):
890
- d1, d2, d3, d4 = (data >> 0) & 1, (data >> 1) & 1, (data >> 2) & 1, (data >> 3) & 1
891
- p1, p2, p3 = d1 ^ d2 ^ d4, d1 ^ d3 ^ d4, d2 ^ d3 ^ d4
892
- return (p1 << 0) | (p2 << 1) | (d1 << 2) | (p3 << 3) | (d2 << 4) | (d3 << 5) | (d4 << 6)
893
-
894
- inputs_list, expected_list = [], []
895
- for data in range(16):
896
- cw = hamming_encode(data)
897
- inputs_list.append([(cw >> i) & 1 for i in range(7)])
898
- expected_list.append([(data >> i) & 1 for i in range(4)])
899
- for data in range(16):
900
- cw = hamming_encode(data)
901
- for flip in range(7):
902
- corrupted = cw ^ (1 << flip)
903
- inputs_list.append([(corrupted >> i) & 1 for i in range(7)])
904
- expected_list.append([(data >> i) & 1 for i in range(4)])
905
 
906
- test_inputs = torch.tensor(inputs_list, device=device, dtype=torch.float32)
907
- test_expected = torch.tensor(expected_list, device=device, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
908
 
909
- return forward, test_inputs, test_expected
 
 
 
 
 
 
 
 
910
 
911
 
912
- def make_generic_forward(circuit: Circuit):
913
- """Create generic forward by loading model.py dynamically."""
914
- model_py = circuit.path / 'model.py'
915
- if not model_py.exists():
916
- return None, None, None
 
917
 
918
- spec = importlib.util.spec_from_file_location("circuit_model", model_py)
919
- module = importlib.util.module_from_spec(spec)
920
- spec.loader.exec_module(module)
921
 
922
- # Find the main function
923
- fn_names = [circuit.spec.name.replace('threshold-', '').replace('-', '_'),
924
- 'forward', 'evaluate', 'run']
 
 
 
 
 
925
 
926
- main_fn = None
927
- for name in dir(module):
928
- if name.lower() in [n.lower() for n in fn_names] and callable(getattr(module, name)):
929
- main_fn = getattr(module, name)
930
- break
931
 
932
- if main_fn is None:
933
- return None, None, None
934
-
935
- # Build inputs
936
- n = circuit.spec.inputs
937
- n_cases = 2 ** n
938
- device = circuit.device
939
-
940
- idx = torch.arange(n_cases, device=device, dtype=torch.long)
941
- bits = torch.arange(n, device=device, dtype=torch.long)
942
- inputs = ((idx.unsqueeze(1) >> bits) & 1).float()
943
-
944
- # Compute expected
945
- outputs = []
946
- for i in range(n_cases):
947
- args = [int(inputs[i, j].item()) for j in range(n)]
948
- result = main_fn(*args, circuit.weights)
949
- if isinstance(result, (list, tuple)):
950
- outputs.append([float(x) for x in result])
951
- else:
952
- outputs.append([float(result)])
953
- expected = torch.tensor(outputs, device=device, dtype=torch.float32)
954
-
955
- def forward(inp, weights):
956
- out = []
957
- for i in range(inp.shape[0]):
958
- args = [int(inp[i, j].item()) for j in range(n)]
959
- result = main_fn(*args, weights)
960
- if isinstance(result, (list, tuple)):
961
- out.append([float(x) for x in result])
962
- else:
963
- out.append([float(result)])
964
- return torch.tensor(out, device=device, dtype=torch.float32)
965
 
966
- return forward, inputs, expected
 
967
 
 
 
 
 
 
 
 
 
 
968
 
969
- # =============================================================================
970
- # MAIN ORCHESTRATOR
971
- # =============================================================================
972
 
973
- def run_all_methods(circuit: Circuit, cfg: Config) -> Dict[str, PruneResult]:
974
- """Run all enabled pruning methods on a circuit."""
975
 
976
  print(f"\n{'='*70}")
977
  print(f" PRUNING: {circuit.spec.name}")
978
  print(f"{'='*70}")
979
 
 
 
 
 
980
  original = circuit.stats()
981
  print(f" Inputs: {circuit.spec.inputs}, Outputs: {circuit.spec.outputs}")
982
  print(f" Neurons: {circuit.spec.neurons}, Layers: {circuit.spec.layers}")
983
  print(f" Parameters: {original['total']}, Non-zero: {original['nonzero']}")
984
  print(f" Magnitude: {original['magnitude']:.0f}")
 
985
  print(f"{'='*70}")
986
 
987
- # Get forward function
988
- if 'hamming74decoder' in circuit.spec.name:
989
- forward_fn, test_inputs, test_expected = make_hamming_decoder_forward(cfg.device)
990
- else:
991
- forward_fn, test_inputs, test_expected = make_generic_forward(circuit)
992
-
993
- if forward_fn is None:
994
- print("ERROR: Could not create forward function")
995
- return {}
996
 
997
- # Create evaluators
998
- def eval_fn(weights):
999
- outputs = forward_fn(test_inputs, weights)
1000
- correct = (outputs == test_expected).all(dim=-1).float().sum()
1001
- return (correct / test_inputs.shape[0]).item()
1002
-
1003
- def batch_eval_fn(population):
1004
- pop_size = next(iter(population.values())).shape[0]
1005
- fitness = torch.zeros(pop_size, device=cfg.device)
1006
- for i in range(pop_size):
1007
- w = {k: v[i] for k, v in population.items()}
1008
- outputs = forward_fn(test_inputs, w)
1009
- correct = (outputs == test_expected).all(dim=-1).float().sum()
1010
- fitness[i] = correct / test_inputs.shape[0]
1011
- return fitness
1012
 
1013
- # Verify initial
1014
- initial = eval_fn(circuit.weights)
1015
- print(f"\n Initial fitness: {initial:.6f}")
1016
- if initial < cfg.fitness_threshold:
1017
  print(" ERROR: Circuit doesn't pass baseline!")
1018
  return {}
1019
 
1020
  results = {}
1021
 
1022
- # Run methods
1023
- if cfg.run_magnitude:
1024
- print(f"\n[1] MAGNITUDE REDUCTION (sequential)")
1025
- results['magnitude'] = prune_magnitude(circuit.clone(), eval_fn, cfg)
1026
- _print_result(results['magnitude'])
1027
-
1028
- if cfg.run_batched_magnitude:
1029
- print(f"\n[2] MAGNITUDE REDUCTION (batched GPU)")
1030
- results['batched'] = prune_magnitude_batched(circuit.clone(), eval_fn, batch_eval_fn, cfg)
1031
- _print_result(results['batched'])
1032
-
1033
- if cfg.run_zero:
1034
- print(f"\n[3] ZERO PRUNING")
1035
- results['zero'] = prune_zero(circuit.clone(), eval_fn, cfg)
1036
- _print_result(results['zero'])
1037
-
1038
- if cfg.run_quantize:
1039
- print(f"\n[4] QUANTIZATION")
1040
- results['quantize'] = prune_quantize(circuit.clone(), eval_fn, cfg)
1041
- _print_result(results['quantize'])
1042
-
1043
- if cfg.run_evolutionary:
1044
- print(f"\n[5] EVOLUTIONARY")
1045
- results['evolutionary'] = prune_evolutionary(circuit.clone(), batch_eval_fn, cfg)
1046
- _print_result(results['evolutionary'])
1047
-
1048
- if cfg.run_annealing:
1049
- print(f"\n[6] SIMULATED ANNEALING")
1050
- results['annealing'] = prune_annealing(circuit.clone(), eval_fn, cfg)
1051
- _print_result(results['annealing'])
1052
-
1053
- if cfg.run_pareto:
1054
- print(f"\n[7] PARETO FRONTIER")
1055
- results['pareto'] = prune_pareto(circuit.clone(), eval_fn, cfg)
1056
- _print_result(results['pareto'])
1057
-
1058
- # Summary
1059
  print(f"\n{'='*70}")
1060
  print(" SUMMARY")
1061
  print(f"{'='*70}")
1062
- print(f"\n{'Method':<20} {'Fitness':<10} {'Magnitude':<12} {'Nonzero':<10} {'Time':<10}")
1063
  print("-" * 70)
1064
- print(f"{'Original':<20} {'1.0000':<10} {original['magnitude']:<12.0f} {original['nonzero']:<10} {'-':<10}")
1065
 
1066
  best_method, best_mag = None, float('inf')
1067
  for name, r in sorted(results.items(), key=lambda x: x[1].final_stats.get('magnitude', float('inf'))):
1068
  mag = r.final_stats.get('magnitude', 0)
1069
  nz = r.final_stats.get('nonzero', 0)
1070
- print(f"{name:<20} {r.fitness:<10.4f} {mag:<12.0f} {nz:<10} {r.time_seconds:<10.1f}s")
 
 
1071
  if r.fitness >= cfg.fitness_threshold and mag < best_mag:
1072
  best_mag = mag
1073
  best_method = name
@@ -1083,25 +1362,47 @@ def _print_result(r: PruneResult):
1083
  print(f" Fitness: {r.fitness:.6f}")
1084
  print(f" Magnitude: {r.final_stats.get('magnitude', 0):.0f}")
1085
  print(f" Nonzero: {r.final_stats.get('nonzero', 0)}")
 
1086
  print(f" Time: {r.time_seconds:.1f}s")
1087
 
1088
 
1089
- # =============================================================================
1090
- # CLI
1091
- # =============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1092
 
1093
  def main():
1094
- parser = argparse.ArgumentParser(description='Prune threshold circuits')
1095
  parser.add_argument('circuit', nargs='?', help='Circuit name')
1096
- parser.add_argument('--list', action='store_true', help='List available circuits')
1097
- parser.add_argument('--all', action='store_true', help='Run on all circuits')
1098
- parser.add_argument('--max-inputs', type=int, default=10, help='Max inputs for --all')
1099
- parser.add_argument('--device', default='cuda', help='cuda or cpu')
1100
- parser.add_argument('--batch-size', type=int, default=80000)
1101
- parser.add_argument('--methods', type=str, help='Comma-separated methods')
1102
  parser.add_argument('--fitness', type=float, default=0.9999)
1103
  parser.add_argument('--quiet', action='store_true')
1104
- parser.add_argument('--save', action='store_true', help='Save best result')
 
 
 
1105
 
1106
  args = parser.parse_args()
1107
 
@@ -1109,24 +1410,30 @@ def main():
1109
  specs = discover_circuits()
1110
  print(f"\nAvailable circuits ({len(specs)}):\n")
1111
  for s in specs:
1112
- print(f" {s.name:<40} {s.inputs}in/{s.outputs}out {s.neurons}N {s.layers}L")
1113
  return
1114
 
 
 
1115
  cfg = Config(
1116
  device=args.device,
1117
- batch_size=args.batch_size,
1118
  fitness_threshold=args.fitness,
1119
- verbose=not args.quiet
 
 
 
1120
  )
1121
 
1122
  if args.methods:
1123
  methods = args.methods.lower().split(',')
1124
- cfg.run_magnitude = 'magnitude' in methods or 'mag' in methods
1125
- cfg.run_batched_magnitude = 'batched' in methods or 'batch' in methods
1126
  cfg.run_zero = 'zero' in methods
1127
- cfg.run_quantize = 'quantize' in methods or 'quant' in methods
1128
  cfg.run_evolutionary = 'evo' in methods or 'evolutionary' in methods
1129
  cfg.run_annealing = 'anneal' in methods or 'sa' in methods
 
 
 
1130
  cfg.run_pareto = 'pareto' in methods
1131
 
1132
  RESULTS_PATH.mkdir(exist_ok=True)
@@ -1136,26 +1443,34 @@ def main():
1136
  print(f"\nRunning on {len(specs)} circuits...")
1137
  for spec in specs:
1138
  try:
1139
- circuit = Circuit(spec.path, cfg.device)
1140
  results = run_all_methods(circuit, cfg)
 
1141
  except Exception as e:
1142
  print(f"ERROR on {spec.name}: {e}")
1143
  elif args.circuit:
1144
- circuit = load_circuit(args.circuit, cfg.device)
 
 
 
 
 
 
 
1145
  results = run_all_methods(circuit, cfg)
1146
 
1147
  if args.save and results:
1148
  best = min(results.values(), key=lambda r: r.final_stats.get('magnitude', float('inf')))
1149
  if best.fitness >= cfg.fitness_threshold:
1150
- path = circuit.save(best.final_weights, f'pruned_{best.method}')
1151
  print(f"\nSaved to: {path}")
1152
  else:
1153
  parser.print_help()
1154
  print("\n\nExamples:")
1155
- print(" python prune.py --list")
1156
- print(" python prune.py threshold-hamming74decoder")
1157
- print(" python prune.py threshold-hamming74decoder --methods mag,zero,evo")
1158
- print(" python prune.py --all --max-inputs 8")
1159
 
1160
 
1161
  if __name__ == '__main__':
 
1
  """
2
+ Threshold Circuit Pruning Framework v2
3
+ ======================================
4
 
5
+ Fully vectorized GPU implementation with VRAM management.
6
 
7
  Methods:
8
+ 1. Magnitude Reduction (vectorized)
9
+ 2. Zero Pruning (vectorized)
10
+ 3. Weight Quantization
11
+ 4. Evolutionary Search (true batched)
12
+ 5. Simulated Annealing
13
+ 6. Neuron Pruning (NEW)
14
+ 7. Lottery Ticket (NEW)
15
+ 8. Topology Search (NEW)
16
+ 9. Pareto Frontier
17
 
18
  Usage:
19
+ python prune_v2.py threshold-hamming74decoder
20
+ python prune_v2.py threshold-hamming74decoder --methods evo,neuron,lottery
21
+ python prune_v2.py --list
 
 
 
22
  """
23
 
24
  import torch
25
+ import torch.nn.functional as F
26
  import json
27
  import time
28
  import random
29
  import argparse
30
+ import math
31
+ import gc
32
  from pathlib import Path
33
  from dataclasses import dataclass, field
34
+ from typing import Dict, List, Tuple, Optional, Callable, Set, Any
 
 
35
  from safetensors.torch import load_file, save_file
36
+ from collections import OrderedDict
37
+ import warnings
38
 
39
+ warnings.filterwarnings('ignore')
 
 
 
40
 
41
  CIRCUITS_PATH = Path('D:/threshold-circuits')
42
  RESULTS_PATH = CIRCUITS_PATH / 'pruned_results'
43
 
44
 
45
+ @dataclass
46
+ class VRAMConfig:
47
+ """VRAM management configuration."""
48
+ total_gb: float = 0.0
49
+ target_residency: float = 0.75
50
+ target_utilization: float = 0.90
51
+ safety_margin: float = 0.10
52
+
53
+ def __post_init__(self):
54
+ if torch.cuda.is_available():
55
+ self.total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
56
+
57
+ @property
58
+ def available_gb(self) -> float:
59
+ return self.total_gb * (self.target_residency - self.safety_margin)
60
+
61
+ def estimate_population_memory(self, n_weights: int, pop_size: int,
62
+ n_cases: int, n_inputs: int, n_outputs: int) -> float:
63
+ """Estimate VRAM in GB for a population evaluation."""
64
+ bytes_per_float = 4
65
+
66
+ pop_weights = pop_size * n_weights * bytes_per_float
67
+ inputs_broadcast = pop_size * n_cases * n_inputs * bytes_per_float
68
+ outputs = pop_size * n_cases * n_outputs * bytes_per_float
69
+ intermediates = pop_size * n_cases * n_weights * bytes_per_float
70
+ fitness = pop_size * bytes_per_float
71
+ overhead = 0.5 * 1e9
72
+
73
+ total = pop_weights + inputs_broadcast + outputs + intermediates + fitness + overhead
74
+ return total / 1e9
75
+
76
+ def max_population_size(self, n_weights: int, n_cases: int,
77
+ n_inputs: int, n_outputs: int) -> int:
78
+ """Calculate maximum safe population size."""
79
+ bytes_per_float = 4
80
+ per_individual = (
81
+ n_weights +
82
+ n_cases * n_inputs +
83
+ n_cases * n_outputs +
84
+ n_cases * n_weights +
85
+ 1
86
+ ) * bytes_per_float
87
+
88
+ available_bytes = self.available_gb * 1e9
89
+ max_pop = int(available_bytes / per_individual)
90
+ return max(100, min(max_pop, 2_000_000))
91
+
92
+
93
+ def get_vram_status() -> Dict:
94
+ """Get current VRAM status."""
95
+ if not torch.cuda.is_available():
96
+ return {'available': False}
97
+
98
+ return {
99
+ 'available': True,
100
+ 'total_gb': torch.cuda.get_device_properties(0).total_memory / 1e9,
101
+ 'allocated_gb': torch.cuda.memory_allocated() / 1e9,
102
+ 'reserved_gb': torch.cuda.memory_reserved() / 1e9,
103
+ 'free_gb': (torch.cuda.get_device_properties(0).total_memory -
104
+ torch.cuda.memory_allocated()) / 1e9
105
+ }
106
+
107
+
108
+ def clear_vram():
109
+ """Force VRAM cleanup."""
110
+ gc.collect()
111
+ if torch.cuda.is_available():
112
+ torch.cuda.empty_cache()
113
+ torch.cuda.synchronize()
114
+
115
+
116
  @dataclass
117
  class Config:
118
+ """Global configuration."""
119
  device: str = 'cuda'
120
  fitness_threshold: float = 0.9999
 
121
  verbose: bool = True
122
+ vram: VRAMConfig = field(default_factory=VRAMConfig)
123
 
 
124
  run_magnitude: bool = True
 
125
  run_zero: bool = True
126
  run_quantize: bool = True
127
  run_evolutionary: bool = True
128
  run_annealing: bool = True
129
+ run_neuron: bool = True
130
+ run_lottery: bool = True
131
+ run_topology: bool = True
132
  run_pareto: bool = True
133
 
 
134
  magnitude_passes: int = 100
135
+ evo_generations: int = 2000
136
+ evo_pop_size: int = 0
137
+ evo_elite_ratio: float = 0.05
138
+ evo_mutation_rate: float = 0.15
139
+ evo_mutation_strength: float = 2.0
140
+ evo_crossover_rate: float = 0.3
141
  evo_parsimony: float = 0.001
142
+ evo_adaptive_mutation: bool = True
143
+
144
+ annealing_iterations: int = 50000
145
  annealing_initial_temp: float = 10.0
146
+ annealing_cooling: float = 0.9995
147
+
148
  quantize_targets: List[float] = field(default_factory=lambda: [-1.0, 0.0, 1.0])
149
+ pareto_levels: List[float] = field(default_factory=lambda: [1.0, 0.99, 0.95, 0.90, 0.80])
150
+
151
+ lottery_rounds: int = 10
152
+ lottery_prune_rate: float = 0.2
153
 
154
+ topology_generations: int = 500
155
+ topology_add_neuron_prob: float = 0.1
156
+ topology_remove_neuron_prob: float = 0.2
157
 
 
 
 
158
 
159
  @dataclass
160
  class CircuitSpec:
161
+ """Circuit metadata."""
162
  name: str
163
  path: Path
164
  inputs: int
 
169
  description: str = ""
170
 
171
 
172
+ @dataclass
173
+ class PruneResult:
174
+ """Pruning result."""
175
+ method: str
176
+ original_stats: Dict
177
+ final_stats: Dict
178
+ final_weights: Dict[str, torch.Tensor]
179
+ fitness: float
180
+ time_seconds: float
181
+ history: List[Dict] = field(default_factory=list)
182
+ metadata: Dict = field(default_factory=dict)
183
+
184
+
185
+ class ThresholdCircuit:
186
+ """
187
+ Vectorized threshold circuit representation.
188
+
189
+ Converts arbitrary threshold circuits to batched tensor operations.
190
+ """
191
 
192
  def __init__(self, path: Path, device: str = 'cuda'):
193
  self.path = Path(path)
194
  self.device = device
195
  self.spec = self._load_spec()
196
  self.weights = self._load_weights()
197
+ self.weight_keys = list(self.weights.keys())
198
+ self.n_weights = sum(t.numel() for t in self.weights.values())
199
+
200
+ self._analyze_structure()
201
+ self._build_vectorized_forward()
202
 
203
  def _load_spec(self) -> CircuitSpec:
204
  with open(self.path / 'config.json') as f:
 
218
  w = load_file(str(self.path / 'model.safetensors'))
219
  return {k: v.float().to(self.device) for k, v in w.items()}
220
 
221
+ def _analyze_structure(self):
222
+ """Analyze circuit topology from weight names."""
223
+ self.neurons = {}
224
+ self.layers_map = {}
225
+
226
+ for key, tensor in self.weights.items():
227
+ parts = key.rsplit('.', 1)
228
+ if len(parts) == 2:
229
+ neuron_path, param_type = parts
230
+ else:
231
+ neuron_path, param_type = key, 'weight'
232
+
233
+ if neuron_path not in self.neurons:
234
+ self.neurons[neuron_path] = {'weight': None, 'bias': None}
235
+
236
+ if 'weight' in param_type:
237
+ self.neurons[neuron_path]['weight'] = key
238
+ elif 'bias' in param_type:
239
+ self.neurons[neuron_path]['bias'] = key
240
+
241
+ def _build_vectorized_forward(self):
242
+ """Build optimized forward function based on circuit type."""
243
+ name = self.spec.name.lower()
244
+
245
+ if 'hamming74decoder' in name:
246
+ self.forward_fn = self._build_hamming_decoder_forward()
247
+ self.test_inputs, self.test_expected = self._build_hamming_decoder_tests()
248
+ elif 'hamming74encoder' in name:
249
+ self.forward_fn = self._build_hamming_encoder_forward()
250
+ self.test_inputs, self.test_expected = self._build_hamming_encoder_tests()
251
+ elif 'winnertakeall' in name:
252
+ self.forward_fn = self._build_wta_forward()
253
+ self.test_inputs, self.test_expected = self._build_generic_tests()
254
+ elif 'decoder' in name or 'thermometer' in name or 'priority' in name:
255
+ self.forward_fn = self._build_single_layer_forward()
256
+ self.test_inputs, self.test_expected = self._build_generic_tests()
257
+ else:
258
+ self.forward_fn = self._build_generic_forward()
259
+ self.test_inputs, self.test_expected = self._build_generic_tests()
260
+
261
+ self.n_cases = self.test_inputs.shape[0]
262
+
263
+ def _build_generic_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
264
+ """Build exhaustive test cases."""
265
+ n = self.spec.inputs
266
+ if n > 20:
267
+ raise ValueError(f"Input space too large: 2^{n}")
268
+
269
+ n_cases = 2 ** n
270
+ idx = torch.arange(n_cases, device=self.device, dtype=torch.long)
271
+ bits = torch.arange(n, device=self.device, dtype=torch.long)
272
+ inputs = ((idx.unsqueeze(1) >> bits) & 1).float()
273
+
274
+ expected = self.forward_fn(inputs, self.weights)
275
+ return inputs, expected
276
+
277
+ def _threshold(self, x: torch.Tensor) -> torch.Tensor:
278
+ """Batched threshold activation: 1 if x >= 0, else 0."""
279
+ return (x >= 0).float()
280
+
281
+ def _build_single_layer_forward(self):
282
+ """Forward for single-layer circuits (decoders, thermometer, etc.)."""
283
+ output_keys = sorted([k for k in self.weights.keys() if '.weight' in k or
284
+ (not any(x in k for x in ['.', '_']) and 'weight' in k)])
285
+
286
+ def forward(inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
287
+ outputs = []
288
+ for key in output_keys:
289
+ base = key.replace('.weight', '').replace('weight', '')
290
+ w_key = key
291
+ b_key = key.replace('weight', 'bias')
292
+
293
+ if w_key in weights and b_key in weights:
294
+ w = weights[w_key].flatten()
295
+ b = weights[b_key].squeeze()
296
+ out = self._threshold(inputs @ w + b)
297
+ outputs.append(out)
298
+
299
+ if outputs:
300
+ return torch.stack(outputs, dim=-1)
301
+ return inputs
302
+
303
+ return forward
304
+
305
+ def _build_wta_forward(self):
306
+ """Forward for winner-take-all."""
307
+ def forward(inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
308
+ outputs = []
309
+ for i in range(4):
310
+ w = weights[f'y{i}.weight'].flatten()
311
+ b = weights[f'y{i}.bias'].squeeze()
312
+ out = self._threshold(inputs @ w + b)
313
+ outputs.append(out)
314
+ return torch.stack(outputs, dim=-1)
315
+ return forward
316
+
317
+ def _xor2_batched(self, a: torch.Tensor, b: torch.Tensor,
318
+ weights: Dict[str, torch.Tensor], prefix: str) -> torch.Tensor:
319
+ """Batched 2-input XOR using threshold gates."""
320
+ inp = torch.stack([a, b], dim=-1)
321
+
322
+ or_w = weights[f'{prefix}.layer1.or.weight'].flatten()[:2]
323
+ or_b = weights[f'{prefix}.layer1.or.bias'].squeeze()
324
+ or_out = self._threshold(inp @ or_w + or_b)
325
+
326
+ nand_w = weights[f'{prefix}.layer1.nand.weight'].flatten()[:2]
327
+ nand_b = weights[f'{prefix}.layer1.nand.bias'].squeeze()
328
+ nand_out = self._threshold(inp @ nand_w + nand_b)
329
+
330
+ l1 = torch.stack([or_out, nand_out], dim=-1)
331
+ l2_w = weights[f'{prefix}.layer2.weight'].flatten()
332
+ l2_b = weights[f'{prefix}.layer2.bias'].squeeze()
333
+
334
+ return self._threshold(l1 @ l2_w + l2_b)
335
+
336
+ def _xor4_batched(self, inputs: torch.Tensor, indices: List[int],
337
+ weights: Dict[str, torch.Tensor], prefix: str) -> torch.Tensor:
338
+ """Batched 4-input XOR."""
339
+ i0, i1, i2, i3 = indices
340
+
341
+ or_w = weights[f'{prefix}.xor_{i0}{i1}.layer1.or.weight'].flatten()
342
+ or_b = weights[f'{prefix}.xor_{i0}{i1}.layer1.or.bias'].squeeze()
343
+ or_out_ab = self._threshold(inputs @ or_w + or_b)
344
+
345
+ nand_w = weights[f'{prefix}.xor_{i0}{i1}.layer1.nand.weight'].flatten()
346
+ nand_b = weights[f'{prefix}.xor_{i0}{i1}.layer1.nand.bias'].squeeze()
347
+ nand_out_ab = self._threshold(inputs @ nand_w + nand_b)
348
+
349
+ l1_ab = torch.stack([or_out_ab, nand_out_ab], dim=-1)
350
+ l2_w = weights[f'{prefix}.xor_{i0}{i1}.layer2.weight'].flatten()
351
+ l2_b = weights[f'{prefix}.xor_{i0}{i1}.layer2.bias'].squeeze()
352
+ xor_ab = self._threshold(l1_ab @ l2_w + l2_b)
353
+
354
+ or_w = weights[f'{prefix}.xor_{i2}{i3}.layer1.or.weight'].flatten()
355
+ or_b = weights[f'{prefix}.xor_{i2}{i3}.layer1.or.bias'].squeeze()
356
+ or_out_cd = self._threshold(inputs @ or_w + or_b)
357
+
358
+ nand_w = weights[f'{prefix}.xor_{i2}{i3}.layer1.nand.weight'].flatten()
359
+ nand_b = weights[f'{prefix}.xor_{i2}{i3}.layer1.nand.bias'].squeeze()
360
+ nand_out_cd = self._threshold(inputs @ nand_w + nand_b)
361
+
362
+ l1_cd = torch.stack([or_out_cd, nand_out_cd], dim=-1)
363
+ l2_w = weights[f'{prefix}.xor_{i2}{i3}.layer2.weight'].flatten()
364
+ l2_b = weights[f'{prefix}.xor_{i2}{i3}.layer2.bias'].squeeze()
365
+ xor_cd = self._threshold(l1_cd @ l2_w + l2_b)
366
+
367
+ inp_final = torch.stack([xor_ab, xor_cd], dim=-1)
368
+ or_w = weights[f'{prefix}.xor_final.layer1.or.weight'].flatten()
369
+ or_b = weights[f'{prefix}.xor_final.layer1.or.bias'].squeeze()
370
+ or_out = self._threshold(inp_final @ or_w + or_b)
371
+
372
+ nand_w = weights[f'{prefix}.xor_final.layer1.nand.weight'].flatten()
373
+ nand_b = weights[f'{prefix}.xor_final.layer1.nand.bias'].squeeze()
374
+ nand_out = self._threshold(inp_final @ nand_w + nand_b)
375
+
376
+ l1_final = torch.stack([or_out, nand_out], dim=-1)
377
+ l2_w = weights[f'{prefix}.xor_final.layer2.weight'].flatten()
378
+ l2_b = weights[f'{prefix}.xor_final.layer2.bias'].squeeze()
379
+
380
+ return self._threshold(l1_final @ l2_w + l2_b)
381
+
382
+ def _build_hamming_decoder_forward(self):
383
+ """Fully vectorized Hamming(7,4) decoder."""
384
+ def forward(inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
385
+ s1 = self._xor4_batched(inputs, [0, 2, 4, 6], weights, 's1')
386
+ s2 = self._xor4_batched(inputs, [1, 2, 5, 6], weights, 's2')
387
+ s3 = self._xor4_batched(inputs, [3, 4, 5, 6], weights, 's3')
388
+
389
+ syndrome = torch.stack([s1, s2, s3], dim=-1)
390
+
391
+ flip3 = self._threshold(syndrome @ weights['flip3.weight'].flatten() +
392
+ weights['flip3.bias'].squeeze())
393
+ flip5 = self._threshold(syndrome @ weights['flip5.weight'].flatten() +
394
+ weights['flip5.bias'].squeeze())
395
+ flip6 = self._threshold(syndrome @ weights['flip6.weight'].flatten() +
396
+ weights['flip6.bias'].squeeze())
397
+ flip7 = self._threshold(syndrome @ weights['flip7.weight'].flatten() +
398
+ weights['flip7.bias'].squeeze())
399
+
400
+ d1 = self._xor2_batched(inputs[:, 2], flip3, weights, 'd1.xor')
401
+ d2 = self._xor2_batched(inputs[:, 4], flip5, weights, 'd2.xor')
402
+ d3 = self._xor2_batched(inputs[:, 5], flip6, weights, 'd3.xor')
403
+ d4 = self._xor2_batched(inputs[:, 6], flip7, weights, 'd4.xor')
404
+
405
+ return torch.stack([d1, d2, d3, d4], dim=-1)
406
+
407
+ return forward
408
+
409
+ def _build_hamming_decoder_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
410
+ """Build Hamming decoder test cases with error injection."""
411
+ def encode(data):
412
+ d1, d2, d3, d4 = (data >> 0) & 1, (data >> 1) & 1, (data >> 2) & 1, (data >> 3) & 1
413
+ p1, p2, p3 = d1 ^ d2 ^ d4, d1 ^ d3 ^ d4, d2 ^ d3 ^ d4
414
+ return (p1 << 0) | (p2 << 1) | (d1 << 2) | (p3 << 3) | (d2 << 4) | (d3 << 5) | (d4 << 6)
415
+
416
+ inputs_list, expected_list = [], []
417
+
418
+ for data in range(16):
419
+ cw = encode(data)
420
+ inputs_list.append([(cw >> i) & 1 for i in range(7)])
421
+ expected_list.append([(data >> i) & 1 for i in range(4)])
422
+
423
+ for data in range(16):
424
+ cw = encode(data)
425
+ for flip in range(7):
426
+ corrupted = cw ^ (1 << flip)
427
+ inputs_list.append([(corrupted >> i) & 1 for i in range(7)])
428
+ expected_list.append([(data >> i) & 1 for i in range(4)])
429
+
430
+ return (torch.tensor(inputs_list, device=self.device, dtype=torch.float32),
431
+ torch.tensor(expected_list, device=self.device, dtype=torch.float32))
432
+
433
+ def _build_hamming_encoder_forward(self):
434
+ """Fully vectorized Hamming(7,4) encoder."""
435
+ def forward(inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
436
+ d1, d2, d3, d4 = inputs[:, 0], inputs[:, 1], inputs[:, 2], inputs[:, 3]
437
+
438
+ def xor3(a, b, c, prefix_ab, prefix_final):
439
+ inp_ab = torch.stack([a, b], dim=-1)
440
+
441
+ or_w = weights[f'{prefix_ab}.layer1.or.weight'].flatten()[:2]
442
+ or_b = weights[f'{prefix_ab}.layer1.or.bias'].squeeze()
443
+ nand_w = weights[f'{prefix_ab}.layer1.nand.weight'].flatten()[:2]
444
+ nand_b = weights[f'{prefix_ab}.layer1.nand.bias'].squeeze()
445
+
446
+ or_out = self._threshold(inp_ab @ or_w + or_b)
447
+ nand_out = self._threshold(inp_ab @ nand_w + nand_b)
448
+
449
+ l1 = torch.stack([or_out, nand_out], dim=-1)
450
+ l2_w = weights[f'{prefix_ab}.layer2.weight'].flatten()
451
+ l2_b = weights[f'{prefix_ab}.layer2.bias'].squeeze()
452
+ xor_ab = self._threshold(l1 @ l2_w + l2_b)
453
+
454
+ inp_final = torch.stack([xor_ab, c], dim=-1)
455
+ or_w = weights[f'{prefix_final}.layer1.or.weight'].flatten()
456
+ or_b = weights[f'{prefix_final}.layer1.or.bias'].squeeze()
457
+ nand_w = weights[f'{prefix_final}.layer1.nand.weight'].flatten()
458
+ nand_b = weights[f'{prefix_final}.layer1.nand.bias'].squeeze()
459
+
460
+ or_out = self._threshold(inp_final @ or_w + or_b)
461
+ nand_out = self._threshold(inp_final @ nand_w + nand_b)
462
+
463
+ l1 = torch.stack([or_out, nand_out], dim=-1)
464
+ l2_w = weights[f'{prefix_final}.layer2.weight'].flatten()
465
+ l2_b = weights[f'{prefix_final}.layer2.bias'].squeeze()
466
+
467
+ return self._threshold(l1 @ l2_w + l2_b)
468
+
469
+ p1 = xor3(d1, d2, d4, 'p1.xor12', 'p1.xor_final')
470
+ p2 = xor3(d1, d3, d4, 'p2.xor13', 'p2.xor_final')
471
+ p3 = xor3(d2, d3, d4, 'p3.xor23', 'p3.xor_final')
472
+
473
+ c3 = self._threshold(inputs @ weights['d1.weight'].flatten() +
474
+ weights['d1.bias'].squeeze())
475
+ c5 = self._threshold(inputs @ weights['d2.weight'].flatten() +
476
+ weights['d2.bias'].squeeze())
477
+ c6 = self._threshold(inputs @ weights['d3.weight'].flatten() +
478
+ weights['d3.bias'].squeeze())
479
+ c7 = self._threshold(inputs @ weights['d4.weight'].flatten() +
480
+ weights['d4.bias'].squeeze())
481
+
482
+ return torch.stack([p1, p2, c3, p3, c5, c6, c7], dim=-1)
483
+
484
+ return forward
485
+
486
+ def _build_hamming_encoder_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
487
+ """Build Hamming encoder test cases."""
488
+ inputs_list, expected_list = [], []
489
+
490
+ for data in range(16):
491
+ d1, d2, d3, d4 = (data >> 0) & 1, (data >> 1) & 1, (data >> 2) & 1, (data >> 3) & 1
492
+ p1, p2, p3 = d1 ^ d2 ^ d4, d1 ^ d3 ^ d4, d2 ^ d3 ^ d4
493
+
494
+ inputs_list.append([d1, d2, d3, d4])
495
+ expected_list.append([p1, p2, d1, p3, d2, d3, d4])
496
+
497
+ return (torch.tensor(inputs_list, device=self.device, dtype=torch.float32),
498
+ torch.tensor(expected_list, device=self.device, dtype=torch.float32))
499
+
500
+ def _build_generic_forward(self):
501
+ """Generic forward for unknown circuit types."""
502
+ def forward(inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
503
+ return inputs[:, :self.spec.outputs]
504
+ return forward
505
+
506
+ def clone_weights(self) -> Dict[str, torch.Tensor]:
507
  return {k: v.clone() for k, v in self.weights.items()}
508
 
509
+ def weights_to_vector(self, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
510
+ """Flatten weights to a single vector."""
511
+ return torch.cat([weights[k].flatten() for k in self.weight_keys])
512
+
513
+ def vector_to_weights(self, vector: torch.Tensor) -> Dict[str, torch.Tensor]:
514
+ """Unflatten vector back to weight dict."""
515
+ weights = {}
516
+ offset = 0
517
+ for k in self.weight_keys:
518
+ shape = self.weights[k].shape
519
+ size = self.weights[k].numel()
520
+ weights[k] = vector[offset:offset + size].view(shape)
521
+ offset += size
522
+ return weights
523
+
524
  def stats(self, weights: Dict[str, torch.Tensor] = None) -> Dict:
525
  w = weights or self.weights
526
  total = sum(t.numel() for t in w.values())
527
  nonzero = sum((t != 0).sum().item() for t in w.values())
528
  mag = sum(t.abs().sum().item() for t in w.values())
529
+ maxw = max(t.abs().max().item() for t in w.values()) if w else 0
530
  unique = set()
531
  for t in w.values():
532
  unique.update(t.flatten().tolist())
 
537
  'sparsity': 1 - nonzero/total if total else 0,
538
  'magnitude': mag,
539
  'max_weight': maxw,
540
+ 'unique_count': len(unique)
 
541
  }
542
 
543
+ def save_weights(self, weights: Dict[str, torch.Tensor], suffix: str = 'pruned') -> Path:
544
  path = self.path / f'model_{suffix}.safetensors'
545
  cpu_w = {k: v.cpu() for k, v in weights.items()}
546
  save_file(cpu_w, str(path))
547
  return path
548
 
549
 
550
class VectorizedEvaluator:
    """
    Population evaluator for threshold circuits.

    A single candidate is scored with one forward pass over all test cases.
    Populations are scored candidate-by-candidate (the circuit forward
    functions accept a single weight dict), with chunking so that an
    oversized population never exceeds the configured VRAM budget.
    """

    def __init__(self, circuit: ThresholdCircuit, cfg: Config):
        self.circuit = circuit
        self.cfg = cfg
        self.device = cfg.device
        self.test_inputs = circuit.test_inputs
        self.test_expected = circuit.test_expected
        self.n_cases = circuit.n_cases
        self.n_weights = circuit.n_weights

        # Largest population that fits inside the VRAM residency target.
        self.max_pop = cfg.vram.max_population_size(
            circuit.n_weights,
            circuit.n_cases,
            circuit.spec.inputs,
            circuit.spec.outputs
        )

        if cfg.verbose:
            print(f"  Max safe population size: {self.max_pop:,}")
            print(f"  VRAM available: {cfg.vram.available_gb:.1f} GB")

    def evaluate_single(self, weights: Dict[str, torch.Tensor]) -> float:
        """Fraction of test cases on which *weights* matches exactly (all outputs)."""
        with torch.no_grad():
            outputs = self.circuit.forward_fn(self.test_inputs, weights)
            correct = (outputs == self.test_expected).all(dim=-1).float().sum()
            return (correct / self.n_cases).item()

    def evaluate_population(self, population: torch.Tensor) -> torch.Tensor:
        """
        Evaluate a population of flattened weight vectors.

        population: [pop_size, n_weights]
        Returns: [pop_size] fitness values in [0, 1].
        """
        pop_size = population.shape[0]

        if pop_size > self.max_pop:
            return self._evaluate_chunked(population)

        fitness = torch.zeros(pop_size, device=self.device)

        with torch.no_grad():
            for i in range(pop_size):
                weights = self.circuit.vector_to_weights(population[i])
                outputs = self.circuit.forward_fn(self.test_inputs, weights)
                correct = (outputs == self.test_expected).all(dim=-1).float().sum()
                fitness[i] = correct / self.n_cases

        return fitness

    def _evaluate_chunked(self, population: torch.Tensor) -> torch.Tensor:
        """Evaluate an oversized population in max_pop-sized chunks."""
        pop_size = population.shape[0]
        chunk_size = self.max_pop
        fitness = torch.zeros(pop_size, device=self.device)

        for start in range(0, pop_size, chunk_size):
            end = min(start + chunk_size, pop_size)
            fitness[start:end] = self.evaluate_population(population[start:end])

            # Release cached allocations between full chunks to avoid creep.
            if (end - start) == chunk_size:
                clear_vram()

        return fitness

    def evaluate_population_parallel(self, population: torch.Tensor) -> torch.Tensor:
        """
        Compatibility alias for evaluate_population.

        NOTE(review): the previous revision duplicated evaluate_population's
        per-candidate loop verbatim while additionally allocating an unused
        [pop, cases, inputs] expansion of the test inputs; the dead
        allocation is removed and the call is delegated.  Results are
        identical to evaluate_population.
        """
        return self.evaluate_population(population)
 
 
646
 
647
+
648
def prune_magnitude_vectorized(circuit: ThresholdCircuit, evaluator: VectorizedEvaluator,
                               cfg: Config) -> PruneResult:
    """Greedy magnitude reduction.

    Each pass steps every nonzero weight one unit toward zero and keeps the
    change whenever fitness stays at or above ``cfg.fitness_threshold``.
    Passes repeat until nothing can be reduced or ``cfg.magnitude_passes``
    is exhausted.

    Candidate collection now uses a single ``torch.nonzero`` per tensor
    instead of the old per-element ``flat[i].item()`` scan, which forced one
    device sync per weight.
    """
    start = time.perf_counter()
    weights = circuit.clone_weights()
    original = circuit.stats(weights)

    history = []
    total_reductions = 0

    if cfg.verbose:
        print(f"  Starting vectorized magnitude reduction...")
        print(f"  Original: mag={original['magnitude']:.0f}, nonzero={original['nonzero']}")

    for pass_num in range(cfg.magnitude_passes):
        # Gather (tensor name, flat index, shape, value) for every nonzero weight.
        candidates = []
        for name, tensor in weights.items():
            flat = tensor.flatten()
            nz = torch.nonzero(flat, as_tuple=False).flatten()
            if nz.numel() == 0:
                continue
            for i, val in zip(nz.tolist(), flat[nz].tolist()):
                candidates.append((name, i, tensor.shape, val))

        if not candidates:
            break

        pass_reductions = 0

        for name, idx, shape, old_val in candidates:
            # Step the weight's magnitude down by one unit (toward zero).
            new_val = old_val - 1 if old_val > 0 else old_val + 1

            flat = weights[name].flatten()
            flat[idx] = new_val
            weights[name] = flat.view(shape)

            fitness = evaluator.evaluate_single(weights)

            if fitness >= cfg.fitness_threshold:
                pass_reductions += 1
                total_reductions += 1
            else:
                # Revert: this reduction broke the circuit.
                flat = weights[name].flatten()
                flat[idx] = old_val
                weights[name] = flat.view(shape)

        stats = circuit.stats(weights)
        history.append({'pass': pass_num, 'reductions': pass_reductions, 'magnitude': stats['magnitude']})

        if cfg.verbose:
            print(f"  Pass {pass_num}: +{pass_reductions} reductions, mag={stats['magnitude']:.0f}")

        # Fixed point reached: no weight can be reduced further.
        if pass_reductions == 0:
            break

    return PruneResult(
        method='magnitude',
        original_stats=original,
        final_stats=circuit.stats(weights),
        final_weights=weights,
        fitness=evaluator.evaluate_single(weights),
        time_seconds=time.perf_counter() - start,
        history=history
    )
710
 
711
 
712
def prune_zero(circuit: ThresholdCircuit, evaluator: VectorizedEvaluator,
               cfg: Config) -> PruneResult:
    """Sparsification: try setting each nonzero weight directly to zero.

    Candidates are visited in random order; a zero is kept whenever fitness
    stays at or above ``cfg.fitness_threshold``, otherwise the old value is
    restored.

    Candidate collection now uses ``torch.nonzero`` per tensor instead of
    the old per-element ``flat[i].item()`` scan (one device sync per weight).
    """
    start = time.perf_counter()
    weights = circuit.clone_weights()
    original = circuit.stats(weights)

    # (tensor name, flat index, shape, value) for every nonzero weight.
    candidates = []
    for name, tensor in weights.items():
        flat = tensor.flatten()
        nz = torch.nonzero(flat, as_tuple=False).flatten()
        if nz.numel() == 0:
            continue
        for i, val in zip(nz.tolist(), flat[nz].tolist()):
            candidates.append((name, i, tensor.shape, val))

    # Random order so early rejections don't systematically favor one layer.
    random.shuffle(candidates)

    if cfg.verbose:
        print(f"  Testing {len(candidates)} candidates for zero pruning...")

    zeroed = 0
    for name, idx, shape, old_val in candidates:
        flat = weights[name].flatten()
        flat[idx] = 0
        weights[name] = flat.view(shape)

        if evaluator.evaluate_single(weights) >= cfg.fitness_threshold:
            zeroed += 1
        else:
            # Restore — this weight is load-bearing.
            flat = weights[name].flatten()
            flat[idx] = old_val
            weights[name] = flat.view(shape)

    if cfg.verbose:
        stats = circuit.stats(weights)
        print(f"  Zeroed {zeroed} weights, mag={stats['magnitude']:.0f}")

    return PruneResult(
        method='zero',
        original_stats=original,
        final_stats=circuit.stats(weights),
        final_weights=weights,
        fitness=evaluator.evaluate_single(weights),
        time_seconds=time.perf_counter() - start
    )
757
 
758
 
759
+ def prune_quantize(circuit: ThresholdCircuit, evaluator: VectorizedEvaluator,
 
 
 
 
 
760
  cfg: Config) -> PruneResult:
761
+ """Quantize weights to target set."""
762
  start = time.perf_counter()
763
+ weights = circuit.clone_weights()
764
+ original = circuit.stats(weights)
765
+ target = torch.tensor(cfg.quantize_targets, device=cfg.device)
766
  target_set = set(cfg.quantize_targets)
767
 
768
  if cfg.verbose:
769
+ print(f" Quantizing to {sorted(cfg.quantize_targets)}...")
 
 
 
 
 
 
 
 
770
 
771
+ quantized = 0
 
772
  for name, tensor in list(weights.items()):
773
  flat = tensor.flatten()
774
  for i in range(len(flat)):
 
779
 
780
  flat[i] = closest
781
  weights[name] = flat.view(tensor.shape)
 
782
 
783
+ if evaluator.evaluate_single(weights) >= cfg.fitness_threshold:
784
+ quantized += 1
 
 
785
  else:
786
  flat[i] = old_val
787
  weights[name] = flat.view(tensor.shape)
788
 
 
 
 
789
  if cfg.verbose:
790
+ stats = circuit.stats(weights)
791
+ print(f" Quantized {quantized} weights, mag={stats['magnitude']:.0f}")
 
 
 
792
 
793
  return PruneResult(
794
  method='quantize',
795
  original_stats=original,
796
+ final_stats=circuit.stats(weights),
797
  final_weights=weights,
798
+ fitness=evaluator.evaluate_single(weights),
 
799
  time_seconds=time.perf_counter() - start
800
  )
801
 
802
 
803
def prune_evolutionary(circuit: ThresholdCircuit, evaluator: VectorizedEvaluator,
                       cfg: Config) -> PruneResult:
    """
    Evolutionary search over flattened weight vectors.

    Features: batched population evaluation, elite preservation, adaptive
    mutation rate (heats up when stagnant, cools on progress), single-point
    crossover, and parsimony pressure (fitness penalized by total magnitude).

    Fix vs. previous revision: parent-selection probabilities now use
    ``torch.softmax`` — the old code called ``F.softmax`` but this module
    never imports ``torch.nn.functional``, so the first generation raised
    NameError.  The crossover mask is also generated on the population's
    device instead of the CPU.
    """
    start = time.perf_counter()
    original = circuit.stats()

    pop_size = cfg.evo_pop_size if cfg.evo_pop_size > 0 else min(evaluator.max_pop, 10000)
    elite_size = max(1, int(pop_size * cfg.evo_elite_ratio))

    if cfg.verbose:
        print(f"  Population: {pop_size}, Elite: {elite_size}")
        print(f"  Generations: {cfg.evo_generations}")

    # Seed with the trained weights plus integer-rounded Gaussian noise;
    # index 0 stays an exact copy of the original network.
    base_vector = circuit.weights_to_vector(circuit.weights)
    population = base_vector.unsqueeze(0).expand(pop_size, -1).clone()
    noise = torch.randn_like(population) * 0.5
    noise[0] = 0
    population = (population + noise).round()

    best_weights = circuit.clone_weights()
    best_score = -float('inf')
    best_fitness = 0.0
    stagnant_generations = 0
    mutation_rate = cfg.evo_mutation_rate
    history = []

    for gen in range(cfg.evo_generations):
        fitness = evaluator.evaluate_population(population)

        # Parsimony pressure: at equal fitness, prefer smaller magnitude.
        magnitudes = population.abs().sum(dim=1)
        adjusted = fitness - cfg.evo_parsimony * magnitudes / circuit.n_weights

        valid_mask = fitness >= cfg.fitness_threshold
        n_valid = valid_mask.sum().item()

        if n_valid > 0:
            # Best is chosen only among candidates that still pass.
            valid_adjusted = adjusted.clone()
            valid_adjusted[~valid_mask] = -float('inf')
            best_idx = valid_adjusted.argmax().item()

            if adjusted[best_idx] > best_score:
                best_score = adjusted[best_idx].item()
                best_fitness = fitness[best_idx].item()
                # Clone before unflattening: vector_to_weights returns views.
                best_weights = circuit.vector_to_weights(population[best_idx].clone())
                stagnant_generations = 0

                if cfg.verbose and gen % 100 == 0:
                    stats = circuit.stats(best_weights)
                    print(f"    Gen {gen}: NEW BEST score={best_score:.4f}, mag={stats['magnitude']:.0f}")
            else:
                stagnant_generations += 1
        else:
            stagnant_generations += 1

        # Adaptive mutation: heat up when stuck, cool down on improvement.
        if cfg.evo_adaptive_mutation:
            if stagnant_generations > 50:
                mutation_rate = min(0.5, mutation_rate * 1.1)
            elif stagnant_generations == 0:
                mutation_rate = max(0.01, mutation_rate * 0.95)

        if gen % 100 == 0:
            stats = circuit.stats(best_weights)
            history.append({
                'gen': gen,
                'best_score': best_score,
                'best_mag': stats['magnitude'],
                'n_valid': n_valid,
                'mutation_rate': mutation_rate
            })

            if cfg.verbose:
                print(f"    Gen {gen}: valid={n_valid}/{pop_size}, best_mag={stats['magnitude']:.0f}, mut={mutation_rate:.3f}")

        # Elitism: the top slice survives unchanged.
        sorted_idx = adjusted.argsort(descending=True)
        elite = population[sorted_idx[:elite_size]].clone()

        # Fitness-proportional parent selection (softmax over scaled scores).
        probs = torch.softmax(adjusted * 10, dim=0)
        parent_idx = torch.multinomial(probs, pop_size - elite_size, replacement=True)
        children = population[parent_idx].clone()

        # Single-point crossover between consecutive pairs of marked children.
        if cfg.evo_crossover_rate > 0:
            crossover_mask = torch.rand(len(children), device=children.device) < cfg.evo_crossover_rate
            n_cross = crossover_mask.sum().item()
            if n_cross > 1:
                cross_idx = torch.where(crossover_mask)[0]
                for i in range(0, len(cross_idx) - 1, 2):
                    p1, p2 = cross_idx[i], cross_idx[i + 1]
                    cross_point = random.randint(1, circuit.n_weights - 1)
                    temp = children[p1, cross_point:].clone()
                    children[p1, cross_point:] = children[p2, cross_point:]
                    children[p2, cross_point:] = temp

        # Integer mutations in [-strength, +strength], applied with prob mutation_rate.
        mutation_mask = torch.rand_like(children) < mutation_rate
        mutations = torch.randint(-int(cfg.evo_mutation_strength),
                                  int(cfg.evo_mutation_strength) + 1,
                                  children.shape, device=cfg.device).float()
        children = children + mutation_mask.float() * mutations

        population = torch.cat([elite, children], dim=0)

        if stagnant_generations > 200:
            if cfg.verbose:
                print(f"    Early stopping at gen {gen} (stagnant)")
            break

    final_stats = circuit.stats(best_weights)

    return PruneResult(
        method='evolutionary',
        original_stats=original,
        final_stats=final_stats,
        final_weights=best_weights,
        fitness=best_fitness,
        time_seconds=time.perf_counter() - start,
        history=history,
        metadata={'final_mutation_rate': mutation_rate, 'generations_run': gen + 1}
    )
929
 
930
 
931
def prune_annealing(circuit: ThresholdCircuit, evaluator: VectorizedEvaluator,
                    cfg: Config) -> PruneResult:
    """Simulated annealing over single-weight mutations.

    Energy = total weight magnitude, with a large additive penalty (1e6)
    whenever fitness drops below the threshold, so infeasible states are
    strongly disfavored yet still traversable at high temperature.

    Fix vs. previous revision: ``math`` is imported here — the module-level
    imports never include it, so the Metropolis acceptance test raised
    NameError on the first iteration.
    """
    import math  # not in the module-level imports; needed for math.exp below

    start = time.perf_counter()
    original = circuit.stats()

    current = circuit.clone_weights()
    current_mag = sum(t.abs().sum().item() for t in current.values())
    current_fitness = evaluator.evaluate_single(current)

    if current_fitness < cfg.fitness_threshold:
        current_energy = 1e6 + current_mag
    else:
        current_energy = current_mag

    best = {k: v.clone() for k, v in current.items()}
    best_energy = current_energy
    best_fitness = current_fitness

    temp = cfg.annealing_initial_temp
    history = []

    if cfg.verbose:
        print(f"  Iterations: {cfg.annealing_iterations}, Initial temp: {temp}")

    for i in range(cfg.annealing_iterations):
        # Propose a neighbor: mutate one random weight (0 means "zero it").
        neighbor = {k: v.clone() for k, v in current.items()}
        name = random.choice(list(neighbor.keys()))
        flat = neighbor[name].flatten()
        idx = random.randint(0, len(flat) - 1)

        mutation = random.choice([-2, -1, 0, 1, 2])
        if mutation == 0:
            flat[idx] = 0
        else:
            flat[idx] = flat[idx] + mutation
        neighbor[name] = flat.view(neighbor[name].shape)

        neighbor_fitness = evaluator.evaluate_single(neighbor)
        neighbor_mag = sum(t.abs().sum().item() for t in neighbor.values())

        if neighbor_fitness < cfg.fitness_threshold:
            neighbor_energy = 1e6 + neighbor_mag
        else:
            neighbor_energy = neighbor_mag

        # Metropolis criterion; temp floor avoids division by zero late in the run.
        delta = neighbor_energy - current_energy
        if delta < 0 or random.random() < math.exp(-delta / max(temp, 1e-10)):
            current = neighbor
            current_energy = neighbor_energy
            current_fitness = neighbor_fitness

            # Track the best feasible (threshold-passing) state seen so far.
            if neighbor_fitness >= cfg.fitness_threshold and neighbor_energy < best_energy:
                best = {k: v.clone() for k, v in current.items()}
                best_energy = neighbor_energy
                best_fitness = neighbor_fitness

        temp *= cfg.annealing_cooling

        if i % 5000 == 0:
            stats = circuit.stats(best)
            history.append({'iter': i, 'temp': temp, 'magnitude': stats['magnitude']})
            if cfg.verbose:
                print(f"    Iter {i}: temp={temp:.4f}, best_mag={stats['magnitude']:.0f}")

    return PruneResult(
        method='annealing',
        original_stats=original,
        final_stats=circuit.stats(best),
        final_weights=best,
        fitness=best_fitness,
        time_seconds=time.perf_counter() - start,
        history=history
    )
1006
 
1007
 
1008
def prune_neuron(circuit: ThresholdCircuit, evaluator: VectorizedEvaluator,
                 cfg: Config) -> PruneResult:
    """
    Neuron-level pruning.

    Groups parameter tensors by neuron prefix (everything before the last
    '.'), zeroes each group wholesale, and keeps the removal whenever the
    circuit still meets the fitness threshold.
    """
    start = time.perf_counter()
    weights = circuit.clone_weights()
    original = circuit.stats(weights)

    # Map neuron prefix -> list of parameter keys belonging to it.
    neuron_groups: Dict[str, List[str]] = {}
    for key in weights:
        head, sep, _tail = key.rpartition('.')
        group = head if sep else key
        neuron_groups.setdefault(group, []).append(key)

    if cfg.verbose:
        print(f"  Found {len(neuron_groups)} neuron groups")

    removed = 0
    for neuron_name, keys in neuron_groups.items():
        # Snapshot the group, zero it, then check the circuit still works.
        saved = {k: weights[k].clone() for k in keys}
        for k in keys:
            weights[k] = torch.zeros_like(weights[k])

        if evaluator.evaluate_single(weights) >= cfg.fitness_threshold:
            removed += 1
            if cfg.verbose:
                print(f"    Removed neuron: {neuron_name}")
        else:
            for k in keys:
                weights[k] = saved[k]

    if cfg.verbose:
        stats = circuit.stats(weights)
        print(f"  Removed {removed} neurons, mag={stats['magnitude']:.0f}")

    return PruneResult(
        method='neuron',
        original_stats=original,
        final_stats=circuit.stats(weights),
        final_weights=weights,
        fitness=evaluator.evaluate_single(weights),
        time_seconds=time.perf_counter() - start,
        metadata={'neurons_removed': removed}
    )
+ )
1062
+
1063
+
1064
def prune_lottery(circuit: ThresholdCircuit, evaluator: VectorizedEvaluator,
                  cfg: Config) -> PruneResult:
    """
    Lottery Ticket Hypothesis pruning.

    Each round removes the smallest-magnitude ``lottery_prune_rate`` fraction
    of the remaining nonzero weights (by masking against the initial weight
    values), then checks whether the resulting subnetwork still meets the
    fitness threshold.  If it does not, the round is reverted and the loop
    stops.
    """
    start = time.perf_counter()
    original = circuit.stats()

    # initial_weights is the reference the mask is applied against; since the
    # circuit is not retrained between rounds it equals the working copy.
    weights = circuit.clone_weights()
    initial_weights = circuit.clone_weights()

    history = []

    if cfg.verbose:
        print(f"  Lottery ticket: {cfg.lottery_rounds} rounds, {cfg.lottery_prune_rate*100:.0f}% per round")

    for round_num in range(cfg.lottery_rounds):
        # Collect (|value|, tensor name, flat index, shape) for every
        # surviving (nonzero) weight.
        all_weights = []
        for name, tensor in weights.items():
            flat = tensor.flatten()
            for i in range(len(flat)):
                val = abs(flat[i].item())
                if val > 0:
                    all_weights.append((val, name, i, tensor.shape))

        if not all_weights:
            break

        # Smallest magnitudes first; prune the bottom fraction this round.
        all_weights.sort(key=lambda x: x[0])
        n_prune = int(len(all_weights) * cfg.lottery_prune_rate)

        if n_prune == 0:
            break

        to_prune = all_weights[:n_prune]

        # Build a 0/1 survival mask from the current sparsity pattern...
        mask = {}
        for name in weights:
            mask[name] = (weights[name] != 0).float()

        # ...then clear the entries selected for pruning this round.
        for _, name, idx, shape in to_prune:
            flat_mask = mask[name].flatten()
            flat_mask[idx] = 0
            mask[name] = flat_mask.view(shape)

        # Apply the mask against the initial weights (lottery-ticket style).
        for name in weights:
            weights[name] = initial_weights[name] * mask[name]

        fitness = evaluator.evaluate_single(weights)
        stats = circuit.stats(weights)

        history.append({
            'round': round_num,
            'pruned': n_prune,
            'remaining': len(all_weights) - n_prune,
            'fitness': fitness,
            'magnitude': stats['magnitude']
        })

        if cfg.verbose:
            print(f"    Round {round_num}: pruned {n_prune}, fitness={fitness:.4f}, mag={stats['magnitude']:.0f}")

        # If this round broke the circuit, restore the pruned entries in the
        # mask, reapply it, and stop pruning.
        if fitness < cfg.fitness_threshold:
            for _, name, idx, shape in to_prune:
                flat_mask = mask[name].flatten()
                flat_mask[idx] = 1
                mask[name] = flat_mask.view(shape)

            for name in weights:
                weights[name] = initial_weights[name] * mask[name]

            if cfg.verbose:
                print(f"    Reverted round {round_num} (fitness dropped)")
            break

    return PruneResult(
        method='lottery',
        original_stats=original,
        final_stats=circuit.stats(weights),
        final_weights=weights,
        fitness=evaluator.evaluate_single(weights),
        time_seconds=time.perf_counter() - start,
        history=history
    )
1150
 
1151
 
1152
def prune_topology(circuit: ThresholdCircuit, evaluator: VectorizedEvaluator,
                   cfg: Config) -> PruneResult:
    """
    Topology search - NEAT-style evolution of circuit structure.

    This is a simplified version that works with fixed topology but
    can zero out entire connection patterns.  Each generation randomly
    disables and/or re-enables one weight/bias group; a candidate is
    accepted only if it still meets the fitness threshold AND lowers the
    total magnitude.
    """
    start = time.perf_counter()
    original = circuit.stats()

    weights = circuit.clone_weights()

    # Pair each '<base>.weight' tensor with its matching '<base>.bias'.
    connection_groups = {}
    for key in weights.keys():
        if 'weight' in key:
            base = key.replace('.weight', '')
            if base not in connection_groups:
                connection_groups[base] = {'weight': None, 'bias': None}
            connection_groups[base]['weight'] = key
            bias_key = key.replace('weight', 'bias')
            if bias_key in weights:
                connection_groups[base]['bias'] = bias_key

    if cfg.verbose:
        print(f"  Found {len(connection_groups)} connection groups")

    # All groups start enabled; the search toggles them off/on.
    active = {k: True for k in connection_groups}

    best_weights = {k: v.clone() for k, v in weights.items()}
    best_active = dict(active)
    # Score = negated magnitude, so "higher score" means "smaller circuit".
    best_score = -sum(t.abs().sum().item() for t in weights.values())

    for gen in range(cfg.topology_generations):
        test_active = dict(active)

        # Mutation 1: disable one currently-active group.
        if random.random() < cfg.topology_remove_neuron_prob:
            candidates = [k for k, v in test_active.items() if v]
            if candidates:
                to_remove = random.choice(candidates)
                test_active[to_remove] = False

        # Mutation 2: re-enable one currently-disabled group.
        if random.random() < cfg.topology_add_neuron_prob:
            candidates = [k for k, v in test_active.items() if not v]
            if candidates:
                to_add = random.choice(candidates)
                test_active[to_add] = True

        # Materialize the candidate: zero every disabled group's tensors.
        test_weights = {k: v.clone() for k, v in weights.items()}
        for group_name, is_active in test_active.items():
            if not is_active:
                info = connection_groups[group_name]
                if info['weight']:
                    test_weights[info['weight']] = torch.zeros_like(test_weights[info['weight']])
                if info['bias']:
                    test_weights[info['bias']] = torch.zeros_like(test_weights[info['bias']])

        fitness = evaluator.evaluate_single(test_weights)

        # Accept only threshold-passing candidates that shrink the circuit;
        # 'active' advances only on acceptance (hill climbing).
        if fitness >= cfg.fitness_threshold:
            mag = sum(t.abs().sum().item() for t in test_weights.values())
            score = -mag

            if score > best_score:
                best_score = score
                best_weights = test_weights
                best_active = dict(test_active)
                active = test_active

        if cfg.verbose and gen % 50 == 0:
            n_active = sum(1 for v in best_active.values() if v)
            stats = circuit.stats(best_weights)
            print(f"    Gen {gen}: {n_active}/{len(connection_groups)} active, mag={stats['magnitude']:.0f}")

    n_removed = sum(1 for v in best_active.values() if not v)

    return PruneResult(
        method='topology',
        original_stats=original,
        final_stats=circuit.stats(best_weights),
        final_weights=best_weights,
        fitness=evaluator.evaluate_single(best_weights),
        time_seconds=time.perf_counter() - start,
        metadata={'connections_removed': n_removed, 'active_groups': best_active}
    )
+ )
1237
 
1238
 
1239
def prune_pareto(circuit: ThresholdCircuit, evaluator: VectorizedEvaluator,
                 cfg: Config) -> PruneResult:
    """Explore Pareto frontier of correctness vs. size.

    Re-runs magnitude reduction once per fitness target in
    ``cfg.pareto_levels`` and records the (target, achieved fitness, size)
    trade-off points in the returned result's ``history``.
    """
    start = time.perf_counter()
    original = circuit.stats()
    frontier = []

    if cfg.verbose:
        print(f"  Exploring Pareto frontier...")

    for target in cfg.pareto_levels:
        # Same config but with the fitness bar lowered to this target.
        relaxed_cfg = Config(
            device=cfg.device,
            fitness_threshold=target,
            magnitude_passes=30,
            verbose=False,
            vram=cfg.vram
        )

        # Starts from the circuit's pristine weights each time, so frontier
        # points are independent of one another.
        result = prune_magnitude_vectorized(circuit, evaluator, relaxed_cfg)

        frontier.append({
            'target': target,
            'actual': result.fitness,
            'magnitude': result.final_stats['magnitude'],
            'nonzero': result.final_stats['nonzero'],
            'sparsity': result.final_stats['sparsity']
        })

        if cfg.verbose:
            print(f"    Target {target:.2f}: fitness={result.fitness:.4f}, mag={result.final_stats['magnitude']:.0f}")

    # NOTE(review): final_stats here is a frontier entry (a different schema
    # from circuit.stats()), final_weights are the UNPRUNED originals, and
    # fitness is taken from the FIRST frontier level — the frontier stored in
    # 'history' is the real product of this method.  Confirm downstream
    # consumers before changing.
    return PruneResult(
        method='pareto',
        original_stats=original,
        final_stats=frontier[-1] if frontier else original,
        final_weights=circuit.clone_weights(),
        fitness=frontier[0]['actual'] if frontier else 1.0,
        time_seconds=time.perf_counter() - start,
        history=frontier
    )
1280
 
 
 
 
1281
 
1282
+ def run_all_methods(circuit: ThresholdCircuit, cfg: Config) -> Dict[str, PruneResult]:
1283
+ """Run all enabled pruning methods."""
1284
 
1285
  print(f"\n{'='*70}")
1286
  print(f" PRUNING: {circuit.spec.name}")
1287
  print(f"{'='*70}")
1288
 
1289
+ vram = get_vram_status()
1290
+ if vram['available']:
1291
+ print(f" VRAM: {vram['total_gb']:.1f} GB total, {vram['free_gb']:.1f} GB free")
1292
+
1293
  original = circuit.stats()
1294
  print(f" Inputs: {circuit.spec.inputs}, Outputs: {circuit.spec.outputs}")
1295
  print(f" Neurons: {circuit.spec.neurons}, Layers: {circuit.spec.layers}")
1296
  print(f" Parameters: {original['total']}, Non-zero: {original['nonzero']}")
1297
  print(f" Magnitude: {original['magnitude']:.0f}")
1298
+ print(f" Test cases: {circuit.n_cases}")
1299
  print(f"{'='*70}")
1300
 
1301
+ evaluator = VectorizedEvaluator(circuit, cfg)
 
 
 
 
 
 
 
 
1302
 
1303
+ initial_fitness = evaluator.evaluate_single(circuit.weights)
1304
+ print(f"\n Initial fitness: {initial_fitness:.6f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
1305
 
1306
+ if initial_fitness < cfg.fitness_threshold:
 
 
 
1307
  print(" ERROR: Circuit doesn't pass baseline!")
1308
  return {}
1309
 
1310
  results = {}
1311
 
1312
+ methods = [
1313
+ ('magnitude', cfg.run_magnitude, lambda: prune_magnitude_vectorized(circuit, evaluator, cfg)),
1314
+ ('zero', cfg.run_zero, lambda: prune_zero(circuit, evaluator, cfg)),
1315
+ ('quantize', cfg.run_quantize, lambda: prune_quantize(circuit, evaluator, cfg)),
1316
+ ('neuron', cfg.run_neuron, lambda: prune_neuron(circuit, evaluator, cfg)),
1317
+ ('lottery', cfg.run_lottery, lambda: prune_lottery(circuit, evaluator, cfg)),
1318
+ ('topology', cfg.run_topology, lambda: prune_topology(circuit, evaluator, cfg)),
1319
+ ('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
1320
+ ('annealing', cfg.run_annealing, lambda: prune_annealing(circuit, evaluator, cfg)),
1321
+ ('pareto', cfg.run_pareto, lambda: prune_pareto(circuit, evaluator, cfg)),
1322
+ ]
1323
+
1324
+ for i, (name, enabled, fn) in enumerate(methods):
1325
+ if enabled:
1326
+ print(f"\n[{i+1}] {name.upper()}")
1327
+ try:
1328
+ clear_vram()
1329
+ results[name] = fn()
1330
+ _print_result(results[name])
1331
+ except Exception as e:
1332
+ print(f" ERROR: {e}")
1333
+ import traceback
1334
+ traceback.print_exc()
1335
+
 
 
 
 
 
 
 
 
 
 
 
 
 
1336
  print(f"\n{'='*70}")
1337
  print(" SUMMARY")
1338
  print(f"{'='*70}")
1339
+ print(f"\n{'Method':<15} {'Fitness':<10} {'Magnitude':<12} {'Nonzero':<10} {'Sparsity':<10} {'Time':<10}")
1340
  print("-" * 70)
1341
+ print(f"{'Original':<15} {'1.0000':<10} {original['magnitude']:<12.0f} {original['nonzero']:<10} {'0.0%':<10} {'-':<10}")
1342
 
1343
  best_method, best_mag = None, float('inf')
1344
  for name, r in sorted(results.items(), key=lambda x: x[1].final_stats.get('magnitude', float('inf'))):
1345
  mag = r.final_stats.get('magnitude', 0)
1346
  nz = r.final_stats.get('nonzero', 0)
1347
+ sp = r.final_stats.get('sparsity', 0) * 100
1348
+ print(f"{name:<15} {r.fitness:<10.4f} {mag:<12.0f} {nz:<10} {sp:<9.1f}% {r.time_seconds:<10.1f}s")
1349
+
1350
  if r.fitness >= cfg.fitness_threshold and mag < best_mag:
1351
  best_mag = mag
1352
  best_method = name
 
1362
  print(f" Fitness: {r.fitness:.6f}")
1363
  print(f" Magnitude: {r.final_stats.get('magnitude', 0):.0f}")
1364
  print(f" Nonzero: {r.final_stats.get('nonzero', 0)}")
1365
+ print(f" Sparsity: {r.final_stats.get('sparsity', 0)*100:.1f}%")
1366
  print(f" Time: {r.time_seconds:.1f}s")
1367
 
1368
 
1369
def discover_circuits(base: Path = CIRCUITS_PATH) -> List[CircuitSpec]:
    """Find all circuit directories under *base*.

    A circuit directory must contain both ``config.json`` and
    ``model.safetensors``.  Directories with unreadable or malformed configs
    are skipped.  Results are sorted by (inputs, neurons).

    Fix vs. previous revision: the bare ``except: pass`` swallowed every
    exception (including KeyboardInterrupt/SystemExit); only expected
    I/O, parse, and schema errors are now ignored.
    """
    circuits = []
    for d in base.iterdir():
        if not (d.is_dir() and (d / 'config.json').exists() and (d / 'model.safetensors').exists()):
            continue
        try:
            with open(d / 'config.json') as f:
                cfg = json.load(f)
            circuits.append(CircuitSpec(
                name=cfg['name'],
                path=d,
                inputs=cfg['inputs'],
                outputs=cfg['outputs'],
                neurons=cfg['neurons'],
                layers=cfg['layers'],
                parameters=cfg['parameters'],
                description=cfg.get('description', '')
            ))
        except (OSError, json.JSONDecodeError, KeyError, TypeError):
            # Unreadable file, invalid JSON, or missing/odd-typed fields:
            # skip this directory and keep scanning.
            continue
    return sorted(circuits, key=lambda x: (x.inputs, x.neurons))
1390
+
1391
 
1392
  def main():
1393
+ parser = argparse.ArgumentParser(description='Prune threshold circuits v2')
1394
  parser.add_argument('circuit', nargs='?', help='Circuit name')
1395
+ parser.add_argument('--list', action='store_true')
1396
+ parser.add_argument('--all', action='store_true')
1397
+ parser.add_argument('--max-inputs', type=int, default=10)
1398
+ parser.add_argument('--device', default='cuda')
1399
+ parser.add_argument('--methods', type=str)
 
1400
  parser.add_argument('--fitness', type=float, default=0.9999)
1401
  parser.add_argument('--quiet', action='store_true')
1402
+ parser.add_argument('--save', action='store_true')
1403
+ parser.add_argument('--evo-pop', type=int, default=0)
1404
+ parser.add_argument('--evo-gens', type=int, default=2000)
1405
+ parser.add_argument('--vram-target', type=float, default=0.75)
1406
 
1407
  args = parser.parse_args()
1408
 
 
1410
  specs = discover_circuits()
1411
  print(f"\nAvailable circuits ({len(specs)}):\n")
1412
  for s in specs:
1413
+ print(f" {s.name:<40} {s.inputs}in/{s.outputs}out {s.neurons}N {s.layers}L {s.parameters}P")
1414
  return
1415
 
1416
+ vram_cfg = VRAMConfig(target_residency=args.vram_target)
1417
+
1418
  cfg = Config(
1419
  device=args.device,
 
1420
  fitness_threshold=args.fitness,
1421
+ verbose=not args.quiet,
1422
+ vram=vram_cfg,
1423
+ evo_pop_size=args.evo_pop,
1424
+ evo_generations=args.evo_gens
1425
  )
1426
 
1427
  if args.methods:
1428
  methods = args.methods.lower().split(',')
1429
+ cfg.run_magnitude = 'mag' in methods or 'magnitude' in methods
 
1430
  cfg.run_zero = 'zero' in methods
1431
+ cfg.run_quantize = 'quant' in methods or 'quantize' in methods
1432
  cfg.run_evolutionary = 'evo' in methods or 'evolutionary' in methods
1433
  cfg.run_annealing = 'anneal' in methods or 'sa' in methods
1434
+ cfg.run_neuron = 'neuron' in methods
1435
+ cfg.run_lottery = 'lottery' in methods
1436
+ cfg.run_topology = 'topology' in methods or 'topo' in methods
1437
  cfg.run_pareto = 'pareto' in methods
1438
 
1439
  RESULTS_PATH.mkdir(exist_ok=True)
 
1443
  print(f"\nRunning on {len(specs)} circuits...")
1444
  for spec in specs:
1445
  try:
1446
+ circuit = ThresholdCircuit(spec.path, cfg.device)
1447
  results = run_all_methods(circuit, cfg)
1448
+ clear_vram()
1449
  except Exception as e:
1450
  print(f"ERROR on {spec.name}: {e}")
1451
  elif args.circuit:
1452
+ path = CIRCUITS_PATH / args.circuit
1453
+ if not path.exists():
1454
+ path = CIRCUITS_PATH / f'threshold-{args.circuit}'
1455
+ if not path.exists():
1456
+ print(f"Circuit not found: {args.circuit}")
1457
+ return
1458
+
1459
+ circuit = ThresholdCircuit(path, cfg.device)
1460
  results = run_all_methods(circuit, cfg)
1461
 
1462
  if args.save and results:
1463
  best = min(results.values(), key=lambda r: r.final_stats.get('magnitude', float('inf')))
1464
  if best.fitness >= cfg.fitness_threshold:
1465
+ path = circuit.save_weights(best.final_weights, f'pruned_{best.method}')
1466
  print(f"\nSaved to: {path}")
1467
  else:
1468
  parser.print_help()
1469
  print("\n\nExamples:")
1470
+ print(" python prune_v2.py --list")
1471
+ print(" python prune_v2.py threshold-hamming74decoder")
1472
+ print(" python prune_v2.py threshold-hamming74decoder --methods evo,neuron,lottery")
1473
+ print(" python prune_v2.py --all --max-inputs 8")
1474
 
1475
 
1476
  if __name__ == '__main__':