threshold-pruner / prune.py
CharlesCNorton
Refactor pruner v5: streamline from 16 to 6 core methods
5a44526
"""
Threshold Circuit Pruner v5 (Refactored)
Streamlined pruning framework with 6 core methods:
- magnitude: Greedy weight reduction
- zero: Try zeroing individual weights
- evolutionary: GPU-parallel genetic algorithm
- exhaustive_mag: Provably optimal for small circuits
- architecture: Search flat 2-layer alternatives
- compositional: For circuits built from known-optimal components
Usage:
python prune.py threshold-xor --methods evo
python prune.py threshold-xor --methods exh_mag
python prune.py threshold-crc16-mag53 --methods comp
python prune.py --list
"""
import torch
import torch.nn.functional as F
import json
import time
import random
import argparse
import math
import gc
import re
import importlib.util
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Callable
from safetensors.torch import load_file, save_file
from collections import defaultdict
from functools import lru_cache
from itertools import combinations, product
import warnings
warnings.filterwarnings('ignore')
# Root directory that holds one sub-directory per circuit; pruning results
# are collected in a sub-folder of it.
CIRCUITS_PATH = Path('D:/threshold-logic-circuits')
RESULTS_PATH = CIRCUITS_PATH.joinpath('pruned_results')
@dataclass
class VRAMConfig:
    """GPU memory budget used to size evaluation batches and populations.

    ``target_residency`` is the fraction of device 0's total memory the pruner
    aims to use; ``safety_margin`` is subtracted from that fraction as
    headroom. Without CUDA, ``total_bytes`` is 0 and ``device_name`` is "CPU".
    """
    target_residency: float = 0.75
    safety_margin: float = 0.05

    def __post_init__(self):
        # Probe device 0 once at construction time.
        has_cuda = torch.cuda.is_available()
        props = torch.cuda.get_device_properties(0) if has_cuda else None
        self.total_bytes = props.total_memory if has_cuda else 0
        self.device_name = props.name if has_cuda else "CPU"

    @property
    def total_gb(self) -> float:
        """Total device memory in decimal gigabytes (1e9 bytes)."""
        return self.total_bytes / 1e9

    @property
    def available_bytes(self) -> int:
        """Byte budget: total memory times (target_residency - safety_margin)."""
        budget_fraction = self.target_residency - self.safety_margin
        return int(self.total_bytes * budget_fraction)

    def current_usage(self) -> Dict:
        """Snapshot of allocated and free memory (GB) and the utilization ratio."""
        if torch.cuda.is_available():
            used = torch.cuda.memory_allocated()
            ratio = used / self.total_bytes if self.total_bytes > 0 else 0
            return {
                'allocated_gb': used / 1e9,
                'free_gb': (self.total_bytes - used) / 1e9,
                'utilization': ratio,
            }
        return {'allocated_gb': 0, 'free_gb': 0, 'utilization': 0}
def clear_vram():
    """Run the Python garbage collector and release cached CUDA memory.

    On machines without CUDA only the garbage collection step runs.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
@dataclass
class Config:
    """Run-wide pruner settings, grouped by the pruning method that reads them."""
    device: str = 'cuda'
    # Minimum fitness (fraction of test cases reproduced exactly) a candidate
    # must reach to be accepted by the pruning methods.
    fitness_threshold: float = 0.9999
    verbose: bool = True
    vram: VRAMConfig = field(default_factory=VRAMConfig)
    # Method toggles; not read in this part of the file (presumably set from
    # the --methods CLI option — TODO confirm).
    run_magnitude: bool = False
    run_zero: bool = False
    run_evolutionary: bool = False
    run_exhaustive_mag: bool = False
    run_architecture: bool = False
    run_compositional: bool = False
    # prune_magnitude: maximum number of passes over all weights.
    magnitude_passes: int = 100
    # prune_exhaustive_mag: skip circuits with more parameters than this.
    exhaustive_max_params: int = 12
    # prune_exhaustive_mag: when >= 0 search only this magnitude; a negative
    # value scans magnitudes 0..original magnitude.
    exhaustive_target_mag: int = -1
    # prune_evolutionary settings.
    evo_generations: int = 2000
    # 0 means "size the population automatically from the VRAM budget".
    evo_pop_size: int = 0
    # Fraction of the population copied unchanged into the next generation.
    evo_elite_ratio: float = 0.05
    # Initial per-weight mutation probability (adapted during the run).
    evo_mutation_rate: float = 0.15
    # Mutations add integers in [-int(strength), int(strength)].
    evo_mutation_strength: float = 2.0
    # Probability that a child is selected for one-point crossover.
    evo_crossover_rate: float = 0.3
    # Score penalty per unit of (total magnitude / number of weights).
    evo_parsimony: float = 0.001
    # prune_architecture settings: hidden width, largest |weight| enumerated,
    # and largest total magnitude searched.
    arch_hidden_neurons: int = 3
    arch_max_weight: int = 3
    arch_max_mag: int = 20
@dataclass
class CircuitSpec:
    """Circuit metadata, populated from config.json by AdaptiveCircuit._load_spec."""
    name: str
    path: Path  # circuit directory
    inputs: int  # number of input bits; test sets enumerate all 2**inputs patterns
    outputs: int  # number of output bits
    neurons: int  # as stated in config.json (_load_spec defaults it to 0)
    layers: int  # as stated in config.json (_load_spec defaults it to 0)
    parameters: int  # as stated in config.json (_load_spec defaults it to 0)
    description: str = ""
@dataclass
class PruneResult:
    """Outcome of running one pruning method on a circuit."""
    method: str  # method name, e.g. 'magnitude', 'zero', 'evolutionary'
    original_stats: Dict  # AdaptiveCircuit.stats() of the starting weights
    final_stats: Dict  # AdaptiveCircuit.stats() of the resulting weights
    final_weights: Dict[str, torch.Tensor]
    fitness: float  # fitness of final_weights as reported by the method
    time_seconds: float  # wall-clock duration measured with time.perf_counter
    metadata: Dict = field(default_factory=dict)  # method-specific extras
class ComputationGraph:
    """Infer a feed-forward threshold network from flat weight names and run it.

    Every ``<neuron>.weight`` key (with an optional ``<neuron>.bias``) becomes
    one neuron computing ``step(w . x + b)`` with ``step(z) = 1 if z >= 0``.
    Wiring is heuristic:

    * a neuron whose fan-in equals ``n_inputs`` reads the raw inputs;
    * a neuron whose name contains ``layer2`` or ``.out`` reads the already
      computed activations sharing its name prefix, when their count equals
      its fan-in;
    * anything else reads the first ``fan_in`` raw inputs.

    Output neurons are the deepest neurons whose name is not a dotted prefix
    of another neuron's name (``a`` is the parent of ``a.b``).
    """

    def __init__(self, weights: Dict[str, torch.Tensor], n_inputs: int, n_outputs: int, device: str):
        self.device = device
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.weights = weights
        # neuron name -> {'weight_key', 'bias_key', 'weight_shape',
        #                 'input_size', 'depth'[, 'input_source']}
        self.neurons = {}
        self.neuron_order = []
        self.output_neurons = []
        # depth -> names of the neurons at that depth
        self.layer_groups = defaultdict(list)
        self._parse_structure()
        self._build_execution_order()

    def _parse_structure(self):
        """Group weight/bias keys by neuron, then assign depths and outputs."""
        neuron_weights = defaultdict(dict)
        for key, tensor in self.weights.items():
            if '.weight' in key:
                neuron_name = key.replace('.weight', '')
                neuron_weights[neuron_name]['weight'] = key
                neuron_weights[neuron_name]['weight_shape'] = tensor.shape
            elif '.bias' in key:
                neuron_name = key.replace('.bias', '')
                neuron_weights[neuron_name]['bias'] = key
        for neuron_name, params in neuron_weights.items():
            # A bias without a matching weight does not define a neuron.
            if 'weight' not in params:
                continue
            shape = params['weight_shape']
            self.neurons[neuron_name] = {
                'weight_key': params['weight'],
                'bias_key': params.get('bias'),
                'weight_shape': shape,
                'input_size': shape[-1],
                'depth': self._estimate_depth(neuron_name),
            }
        # layer_groups is built from the final depths inside this call.
        self._infer_depth_from_shapes()
        self._identify_outputs()

    def _estimate_depth(self, name: str) -> int:
        """Heuristic depth: N from a ``layerN`` tag plus the number of dots."""
        match = re.search(r'layer(\d+)', name)
        depth = int(match.group(1)) if match else 0
        return depth + name.count('.')

    def _infer_depth_from_shapes(self):
        """Mark neurons whose fan-in equals n_inputs as raw-input, depth 0."""
        for info in self.neurons.values():
            if info.get('input_size', 0) == self.n_inputs:
                info['depth'] = 0
                info['input_source'] = 'raw'
        self.layer_groups = defaultdict(list)
        for name, info in self.neurons.items():
            self.layer_groups[info['depth']].append(name)

    def _identify_outputs(self):
        """Select up to n_outputs leaf neurons, deepest first, ties by name.

        A neuron is a parent (not a leaf) when another neuron's name starts
        with ``<name>.``; a plain substring test would wrongly treat ``h1`` as
        the parent of ``ch1.x``.
        """
        names = list(self.neurons)
        candidates = [
            (name, self.neurons[name]['depth'])
            for name in names
            if not any(other.startswith(name + '.') for other in names)
        ]
        candidates.sort(key=lambda c: (-c[1], c[0]))
        self.output_neurons = [name for name, _ in candidates[:self.n_outputs]]

    def _build_execution_order(self):
        """Evaluate neurons in increasing depth, ties broken by name."""
        self.neuron_order = sorted(self.neurons.keys(), key=lambda n: (self.neurons[n]['depth'], n))

    def forward_single(self, inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Evaluate the network on ``inputs`` using ``weights``.

        Args:
            inputs: one input vector or a batch of input rows.
            weights: mapping with the parsed weight/bias keys; neurons whose
                weight key is missing are skipped.

        Returns:
            0./1. activations of the output neurons (in sorted name order)
            stacked along the last dimension.
        """
        activations = {'input': inputs}
        for neuron_name in self.neuron_order:
            info = self.neurons[neuron_name]
            w_key, b_key = info['weight_key'], info['bias_key']
            if not (w_key and w_key in weights):
                continue
            w = weights[w_key]
            if w.dim() == 1:
                w = w.unsqueeze(0)
            inp = self._get_neuron_input(neuron_name, activations, inputs, w.shape[-1])
            if inp.dim() == 1:
                out = inp @ w.flatten()
            else:
                # (B, 1, fan_in) @ (1, fan_in, n_rows) -> (B, n_rows)
                out = (inp.unsqueeze(-2) @ w.T.unsqueeze(0)).squeeze(-2)
            if out.dim() > 1 and out.shape[-1] == 1:
                out = out.squeeze(-1)
            if b_key and b_key in weights:
                out = out + weights[b_key].squeeze()
            activations[neuron_name] = (out >= 0).float()
        return self._collect_outputs(activations)

    def _get_neuron_input(self, neuron_name: str, activations: Dict, raw_input: torch.Tensor, expected_size: int) -> torch.Tensor:
        """Heuristically choose the signals feeding ``neuron_name`` (see class doc)."""
        info = self.neurons.get(neuron_name, {})
        if info.get('input_source') == 'raw' or expected_size == self.n_inputs:
            return raw_input
        if 'layer2' in neuron_name or '.out' in neuron_name:
            base = neuron_name.replace('.layer2', '').replace('.out', '')
            hidden_keys = [k for k in activations if k.startswith(base) and k != neuron_name and k != 'input']
            if len(hidden_keys) == expected_size:
                return torch.stack([activations[k] for k in sorted(hidden_keys)], dim=-1)
        return raw_input[..., :expected_size] if raw_input.shape[-1] >= expected_size else raw_input

    def _collect_outputs(self, activations: Dict) -> torch.Tensor:
        """Stack computed output activations; zeros(n_outputs) if none were computed."""
        outputs = [activations[n] for n in sorted(self.output_neurons) if n in activations]
        if outputs:
            return torch.stack(outputs, dim=-1) if outputs[0].dim() > 0 else torch.stack(outputs)
        return torch.zeros(self.n_outputs, device=self.device)
class AdaptiveCircuit:
    """A threshold circuit loaded from a directory, plus its exhaustive test set.

    The directory must contain ``config.json`` and a ``*.safetensors`` weight
    file. If it also contains a ``model.py`` exposing ``forward``, that function
    is the reference used to build expected outputs; otherwise a
    ComputationGraph inferred from the weight names is used.
    """

    def __init__(self, path: Path, device: str = 'cuda', weights_file: Optional[str] = None):
        self.path = Path(path)
        self.device = device
        self.spec = self._load_spec()
        self.weights = self._load_weights(weights_file)
        self.weight_keys = list(self.weights.keys())
        self.n_weights = sum(t.numel() for t in self.weights.values())
        self.native_forward = self._try_load_native_forward()
        self.has_native = self.native_forward is not None
        print(f" [LOAD] Native forward: {'FOUND' if self.has_native else 'NOT FOUND'}")
        print(f" [LOAD] Parsing circuit topology...")
        self.graph = ComputationGraph(self.weights, self.spec.inputs, self.spec.outputs, device)
        print(f" [LOAD] Found {len(self.graph.neurons)} neurons across {len(self.graph.layer_groups)} layers")
        print(f" [LOAD] Output neurons: {sorted(self.graph.output_neurons)}")
        print(f" [LOAD] Building test cases...")
        self.test_inputs, self.test_expected = self._build_tests()
        self.n_cases = self.test_inputs.shape[0]
        print(f" [LOAD] Generated {self.n_cases} test cases")
        self._compile_fast_forward()
        print(f" [LOAD] Circuit ready: {self.n_weights} weight parameters")

    def _try_load_native_forward(self) -> Optional[Callable]:
        """Return ``forward`` from ``<path>/model.py``, or None if unavailable.

        SECURITY: this executes model.py, so only load circuits from trusted
        directories. Any import failure is deliberately treated as "no native
        forward" so the graph-based evaluator is used instead.
        """
        model_py = self.path / 'model.py'
        if not model_py.exists():
            return None
        try:
            spec = importlib.util.spec_from_file_location("circuit_model", model_py)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return getattr(module, 'forward', None)
        except Exception:
            return None

    def evaluate_native(self, inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the reference forward on all ``inputs`` and return a 2-D float tensor.

        Tries one batched call first. If that raises or does not return a
        tensor, it calls ``forward`` once per row with a list of Python ints and
        CPU weights, recording all-zero outputs for rows that raise. Without a
        native forward, it delegates to the inferred graph.
        """
        if not self.has_native:
            return self.graph.forward_single(inputs, weights)
        try:
            out = self.native_forward(inputs, weights)
            if isinstance(out, torch.Tensor):
                return out.to(self.device) if out.dim() > 1 else out.unsqueeze(-1).to(self.device)
        except Exception:
            pass  # best effort: fall back to per-row evaluation below
        cpu_weights = {k: v.cpu() for k, v in weights.items()}
        results = []
        for i in range(inputs.shape[0]):
            inp = [int(x) for x in inputs[i].cpu().tolist()]
            try:
                out = self.native_forward(inp, cpu_weights)
                results.append([float(x) for x in out] if isinstance(out, (list, tuple)) else [float(out)])
            except Exception:
                results.append([0.0] * self.spec.outputs)
        return torch.tensor(results, device=self.device, dtype=torch.float32)

    def _load_spec(self) -> CircuitSpec:
        """Read ``config.json``, accepting both ``inputs`` and ``input_size`` style keys."""
        with open(self.path / 'config.json') as f:
            cfg = json.load(f)
        return CircuitSpec(
            name=cfg.get('name', self.path.name),
            path=self.path,
            inputs=cfg.get('inputs', cfg.get('input_size', 0)),
            outputs=cfg.get('outputs', cfg.get('output_size', 0)),
            neurons=cfg.get('neurons', 0),
            layers=cfg.get('layers', 0),
            parameters=cfg.get('parameters', 0),
            description=cfg.get('description', '')
        )

    def _load_weights(self, weights_file: Optional[str] = None) -> Dict[str, torch.Tensor]:
        """Load weights as float32 on ``self.device``.

        Uses ``weights_file`` if given, else ``model.safetensors``. If that file
        is missing, it falls back to the alphabetically first ``*.safetensors``
        in the directory. Sorting makes the choice deterministic when several
        files exist, e.g. pruned outputs written by save_weights.
        """
        if weights_file:
            sf = self.path / weights_file
        else:
            sf = self.path / 'model.safetensors'
        if not sf.exists():
            candidates = sorted(self.path.glob('*.safetensors'))
            sf = candidates[0] if candidates else sf
        w = load_file(str(sf))
        return {k: v.float().to(self.device) for k, v in w.items()}

    def _build_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build (inputs, expected) from the native forward if present, else the graph."""
        if self.has_native:
            return self._build_native_tests()
        return self._build_exhaustive_tests()

    def _build_exhaustive_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Enumerate all 2**n input patterns (bit b of case i = input b) on device."""
        n = self.spec.inputs
        if n > 24:
            raise ValueError(f"Input space too large: 2^{n}")
        n_cases = 2 ** n
        idx = torch.arange(n_cases, device=self.device, dtype=torch.long)
        bits = torch.arange(n, device=self.device, dtype=torch.long)
        inputs = ((idx.unsqueeze(1) >> bits) & 1).float()
        expected = self.graph.forward_single(inputs, self.weights)
        return inputs, expected

    def _build_native_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Enumerate all 2**n inputs, calling the native forward once per case on CPU."""
        n = self.spec.inputs
        if n > 20:
            raise ValueError(f"Input space too large: 2^{n}")
        n_cases = 2 ** n
        inputs_list, expected_list = [], []
        cpu_weights = {k: v.cpu() for k, v in self.weights.items()}
        for i in range(n_cases):
            inp = [(i >> b) & 1 for b in range(n)]
            inputs_list.append(inp)
            out = self.native_forward(inp, cpu_weights)
            expected_list.append([float(x) for x in out] if isinstance(out, (list, tuple)) else [float(out)])
        return (torch.tensor(inputs_list, device=self.device, dtype=torch.float32),
                torch.tensor(expected_list, device=self.device, dtype=torch.float32))

    def _compile_fast_forward(self):
        """Record (key, start, end, shape) slices for flat-vector <-> dict conversion."""
        self.weight_layout = []
        offset = 0
        for key in self.weight_keys:
            size = self.weights[key].numel()
            self.weight_layout.append((key, offset, offset + size, self.weights[key].shape))
            offset += size
        self.base_vector = self.weights_to_vector(self.weights)

    def weights_to_vector(self, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Concatenate weights, flattened, in ``weight_keys`` order."""
        return torch.cat([weights[k].flatten() for k in self.weight_keys])

    def vector_to_weights(self, vector: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Inverse of weights_to_vector; the returned tensors are views of ``vector``."""
        weights = {}
        for key, start, end, shape in self.weight_layout:
            weights[key] = vector[start:end].view(shape)
        return weights

    def clone_weights(self) -> Dict[str, torch.Tensor]:
        """Independent copies of the loaded weights."""
        return {k: v.clone() for k, v in self.weights.items()}

    def stats(self, weights: Optional[Dict[str, torch.Tensor]] = None) -> Dict:
        """Summary of ``weights`` (default: the loaded weights).

        Returns a dict with total parameter count, non-zero count, sparsity,
        total absolute magnitude and the largest absolute weight. An explicitly
        passed empty dict is summarised as empty rather than being replaced by
        the loaded weights.
        """
        w = self.weights if weights is None else weights
        total = sum(t.numel() for t in w.values())
        nonzero = sum((t != 0).sum().item() for t in w.values())
        mag = sum(t.abs().sum().item() for t in w.values())
        # Skip empty tensors: .max() on an empty tensor raises.
        maxw = max((t.abs().max().item() for t in w.values() if t.numel()), default=0)
        return {
            'total': total,
            'nonzero': nonzero,
            'sparsity': 1 - nonzero / total if total else 0,
            'magnitude': mag,
            'max_weight': maxw
        }

    def save_weights(self, weights: Dict[str, torch.Tensor], suffix: str = 'pruned') -> Path:
        """Write ``weights`` to ``<path>/model_<suffix>.safetensors`` and return that path."""
        path = self.path / f'model_{suffix}.safetensors'
        save_file({k: v.cpu() for k, v in weights.items()}, str(path))
        return path
class BatchedEvaluator:
    """Score candidate weight sets against the circuit's exhaustive test cases.

    Fitness is the fraction of test cases whose full output row equals the
    expected row. An output whose shape cannot be aligned with the expected
    outputs scores 0. Populations are evaluated one individual at a time, in
    chunks of at most ``max_batch`` rows.
    """

    def __init__(self, circuit: AdaptiveCircuit, cfg: Config):
        self.circuit = circuit
        self.cfg = cfg
        self.device = cfg.device
        self.test_inputs = circuit.test_inputs
        self.test_expected = circuit.test_expected
        self.n_cases = circuit.n_cases
        self.n_weights = circuit.n_weights
        if cfg.verbose:
            print(f" [EVAL] Initializing evaluator...")
        self._calculate_batch_size()
        self._validate_evaluation()
        if cfg.verbose:
            print(f" [EVAL] Evaluator ready: batch={self.max_batch:,}, native={circuit.has_native}")

    def _calculate_batch_size(self):
        """Estimate individuals per chunk from the VRAM budget, clamped to [1000, 5_000_000]."""
        bytes_per_ind = self.n_weights * 4 * 2 + self.n_cases * self.circuit.spec.outputs * 4 + 4096
        available = self.cfg.vram.available_bytes
        self.max_batch = max(1000, min(available // max(bytes_per_ind, 1), 5_000_000))

    def _validate_evaluation(self):
        """Warn (when verbose) if the unpruned weights do not reproduce the tests."""
        fitness = self.evaluate_single(self.circuit.weights)
        if fitness < 0.999 and self.cfg.verbose:
            print(f" [EVAL WARNING] Original weights fitness={fitness:.4f}")

    def _forward(self, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the circuit on all test inputs via the native forward or the graph."""
        if self.circuit.has_native:
            return self.circuit.evaluate_native(self.test_inputs, weights)
        return self.circuit.graph.forward_single(self.test_inputs, weights)

    def _score(self, outputs: torch.Tensor) -> Optional[torch.Tensor]:
        """Fraction of exactly-matching test rows, or None if shapes cannot be aligned."""
        expected = self.test_expected
        if outputs.shape != expected.shape and outputs.dim() == 1:
            # A single output vector (e.g. the graph's no-output fallback):
            # repeat it for every test case.
            outputs = outputs.unsqueeze(0).expand(expected.shape[0], -1)
        if outputs.shape != expected.shape:
            # Refuse to compare: broadcasting mismatched shapes gives bogus scores.
            return None
        correct = (outputs == expected).all(dim=-1).float().sum()
        return correct / self.n_cases

    def evaluate_single(self, weights: Dict[str, torch.Tensor]) -> float:
        """Fitness of one weight dict as a Python float (0.0 on shape mismatch)."""
        with torch.no_grad():
            score = self._score(self._forward(weights))
        return 0.0 if score is None else score.item()

    def evaluate_population(self, population: torch.Tensor) -> torch.Tensor:
        """Fitness of each row of ``population`` (flat weight vectors), shape (pop_size,)."""
        pop_size = population.shape[0]
        if pop_size > self.max_batch:
            return self._evaluate_chunked(population)
        return self._evaluate_sequential(population)

    def _evaluate_sequential(self, population: torch.Tensor) -> torch.Tensor:
        """Score each row in turn; rows whose outputs cannot be aligned score 0."""
        pop_size = population.shape[0]
        fitness = torch.zeros(pop_size, device=self.device)
        with torch.no_grad():
            for i in range(pop_size):
                weights = self.circuit.vector_to_weights(population[i])
                score = self._score(self._forward(weights))
                if score is not None:
                    fitness[i] = score
        return fitness

    def _evaluate_chunked(self, population: torch.Tensor) -> torch.Tensor:
        """Evaluate ``population`` in slices of at most ``max_batch`` rows."""
        pop_size = population.shape[0]
        fitness = torch.zeros(pop_size, device=self.device)
        for start in range(0, pop_size, self.max_batch):
            end = min(start + self.max_batch, pop_size)
            fitness[start:end] = self._evaluate_sequential(population[start:end])
        return fitness
# =============================================================================
# PRUNING METHODS
# =============================================================================
def prune_magnitude(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Iterative magnitude reduction - try reducing each weight toward zero.

    Each pass collects every non-zero weight, shuffles them, and moves each one
    step toward zero: by 1, or straight to 0 when |w| < 1. Capping the step
    prevents fractional weights from jumping across zero (e.g. 0.5 -> -0.5)
    without shrinking, which would count as a "reduction" on every pass. A step
    is kept only if fitness stays >= ``cfg.fitness_threshold``. Stops after
    ``cfg.magnitude_passes`` passes or after a pass that keeps no step.
    """
    start = time.perf_counter()
    weights = circuit.clone_weights()
    original = circuit.stats(weights)
    if cfg.verbose:
        print(f" Starting magnitude reduction (passes={cfg.magnitude_passes})...")
    for pass_num in range(cfg.magnitude_passes):
        candidates = []
        for name, tensor in weights.items():
            for i, val in enumerate(tensor.flatten().tolist()):
                if val == 0:
                    continue
                step = min(1.0, abs(val))
                new_val = val - step if val > 0 else val + step
                candidates.append((name, i, tensor.shape, val, new_val))
        if not candidates:
            break
        random.shuffle(candidates)
        reductions = 0
        for name, idx, shape, old_val, new_val in candidates:
            # flatten() may return a view or a copy; reassigning weights[name]
            # keeps the dict correct in either case.
            flat = weights[name].flatten()
            flat[idx] = new_val
            weights[name] = flat.view(shape)
            if evaluator.evaluate_single(weights) >= cfg.fitness_threshold:
                reductions += 1
            else:
                flat[idx] = old_val
                weights[name] = flat.view(shape)
        if cfg.verbose:
            stats = circuit.stats(weights)
            print(f" Pass {pass_num}: {reductions} reductions, mag={stats['magnitude']:.0f}")
        if reductions == 0:
            break
    return PruneResult(
        method='magnitude',
        original_stats=original,
        final_stats=circuit.stats(weights),
        final_weights=weights,
        fitness=evaluator.evaluate_single(weights),
        time_seconds=time.perf_counter() - start
    )
def prune_zero(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Try to set each non-zero weight to exactly zero, one weight at a time.

    Non-zero weights are visited once in random order. A zeroing is kept only
    when the circuit's fitness stays at or above ``cfg.fitness_threshold``;
    otherwise the weight's previous value is restored.
    """
    t0 = time.perf_counter()
    pruned = circuit.clone_weights()
    before = circuit.stats(pruned)
    todo = [
        (key, pos, tensor.shape, value)
        for key, tensor in pruned.items()
        for pos, value in enumerate(tensor.flatten().tolist())
        if value != 0
    ]
    random.shuffle(todo)
    if cfg.verbose:
        print(f" Testing {len(todo)} non-zero weights for zeroing...")
    n_zeroed = 0
    for key, pos, shape, value in todo:
        flat = pruned[key].flatten()
        flat[pos] = 0
        pruned[key] = flat.view(shape)
        keep = evaluator.evaluate_single(pruned) >= cfg.fitness_threshold
        if keep:
            n_zeroed += 1
        else:
            flat[pos] = value
            pruned[key] = flat.view(shape)
    if cfg.verbose:
        print(f" Zeroed {n_zeroed}/{len(todo)} weights")
    return PruneResult(
        method='zero',
        original_stats=before,
        final_stats=circuit.stats(pruned),
        final_weights=pruned,
        fitness=evaluator.evaluate_single(pruned),
        time_seconds=time.perf_counter() - t0
    )
def prune_evolutionary(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Evolutionary search for a low-magnitude weight vector that keeps fitness.

    Individuals are flat weight vectors. Score = fitness minus
    ``cfg.evo_parsimony * magnitude / n_weights``. Only individuals with
    fitness >= ``cfg.fitness_threshold`` can become the reported best. Each
    generation:

    * keeps an elite;
    * samples parents by a softmax of the score;
    * applies one-point crossover and integer mutations.

    The mutation rate grows while the search stagnates and decays after an
    improvement. The run ends after ``cfg.evo_generations`` generations or
    after more than 200 generations without improvement.
    """
    start = time.perf_counter()
    original = circuit.stats()
    n_weights = circuit.n_weights
    if cfg.evo_pop_size > 0:
        pop_size = cfg.evo_pop_size
    else:
        # Auto-size from the remaining VRAM budget (~3 float32 copies per individual).
        bytes_per_ind = n_weights * 4 * 3
        if torch.cuda.is_available():
            available = cfg.vram.available_bytes - torch.cuda.memory_allocated()
        else:
            available = cfg.vram.available_bytes
        pop_size = max(10000, min(available // max(bytes_per_ind, 1), evaluator.max_batch, 500000))
    elite_size = max(1, int(pop_size * cfg.evo_elite_ratio))
    if cfg.verbose:
        print(f" [EVO] Population: {pop_size:,}, Elite: {elite_size:,}, Generations: {cfg.evo_generations}")
    base_vector = circuit.weights_to_vector(circuit.weights)
    population = base_vector.unsqueeze(0).expand(pop_size, -1).clone()
    # Seed diversity: rows [0, n_exact) stay exact copies of the original.
    # Rows up to pop_size//2 get +-1 noise on ~10% of the weights; the rest
    # get up to +-2 noise on ~25% of the weights.
    n_exact = max(elite_size, pop_size // 10)
    for i in range(n_exact, pop_size // 2):
        n_muts = max(1, n_weights // 10)
        mut_idx = torch.randperm(n_weights)[:n_muts]
        population[i, mut_idx] += torch.randint(-1, 2, (n_muts,), device=cfg.device, dtype=population.dtype)
    for i in range(pop_size // 2, pop_size):
        n_muts = max(1, n_weights // 4)
        mut_idx = torch.randperm(n_weights)[:n_muts]
        population[i, mut_idx] += torch.randint(-2, 3, (n_muts,), device=cfg.device, dtype=population.dtype)
    best_weights = circuit.clone_weights()
    best_fitness = evaluator.evaluate_single(best_weights)
    best_mag = original['magnitude']
    best_score = best_fitness - cfg.evo_parsimony * best_mag / n_weights if best_fitness >= cfg.fitness_threshold else -float('inf')
    stagnant = 0
    mutation_rate = cfg.evo_mutation_rate
    # Tracked separately so metadata is defined even if the loop never runs.
    generations_run = 0
    for gen in range(cfg.evo_generations):
        generations_run = gen + 1
        fitness = evaluator.evaluate_population(population)
        magnitudes = population.abs().sum(dim=1)
        adjusted = fitness - cfg.evo_parsimony * magnitudes / n_weights
        valid_mask = fitness >= cfg.fitness_threshold
        n_valid = valid_mask.sum().item()
        if n_valid > 0:
            valid_adjusted = adjusted.clone()
            valid_adjusted[~valid_mask] = -float('inf')
            best_idx = valid_adjusted.argmax().item()
            if adjusted[best_idx] > best_score:
                best_score = adjusted[best_idx].item()
                best_fitness = fitness[best_idx].item()
                best_weights = circuit.vector_to_weights(population[best_idx].clone())
                best_mag = magnitudes[best_idx].item()
                stagnant = 0
            else:
                stagnant += 1
        else:
            stagnant += 1
        # Adaptive mutation
        if stagnant > 50:
            mutation_rate = min(0.5, mutation_rate * 1.1)
        elif stagnant == 0:
            mutation_rate = max(0.01, mutation_rate * 0.95)
        if cfg.verbose and (gen % 50 == 0 or gen == cfg.evo_generations - 1):
            print(f" Gen {gen:4d} | valid: {n_valid:6,}/{pop_size:,} | mag: {best_mag:.0f} | stag: {stagnant}")
        # Selection and reproduction
        sorted_idx = adjusted.argsort(descending=True)
        elite = population[sorted_idx[:elite_size]].clone()
        probs = F.softmax(adjusted * 10, dim=0)
        parent_idx = torch.multinomial(probs, pop_size - elite_size, replacement=True)
        children = population[parent_idx].clone()
        # One-point crossover needs at least two genes to pick a cut point from
        # (random.randint(1, 0) would raise for single-weight circuits).
        if cfg.evo_crossover_rate > 0 and n_weights > 1:
            cross_mask = torch.rand(len(children), device=cfg.device) < cfg.evo_crossover_rate
            cross_idx = torch.where(cross_mask)[0]
            for i in range(0, len(cross_idx) - 1, 2):
                p1, p2 = cross_idx[i].item(), cross_idx[i + 1].item()
                point = random.randint(1, n_weights - 1)
                temp = children[p1, point:].clone()
                children[p1, point:] = children[p2, point:]
                children[p2, point:] = temp
        # Mutation
        mut_mask = torch.rand_like(children) < mutation_rate
        mutations = torch.randint(-int(cfg.evo_mutation_strength), int(cfg.evo_mutation_strength) + 1,
                                  children.shape, device=cfg.device, dtype=children.dtype)
        children = children + mut_mask * mutations
        population = torch.cat([elite, children], dim=0)
        if stagnant > 200:
            if cfg.verbose:
                print(f" [EVO] Early stop at generation {gen}")
            break
    if cfg.verbose:
        final_stats = circuit.stats(best_weights)
        print(f" [EVO] Final: mag={final_stats['magnitude']:.0f} (was {original['magnitude']:.0f})")
    return PruneResult(
        method='evolutionary',
        original_stats=original,
        final_stats=circuit.stats(best_weights),
        final_weights=best_weights,
        fitness=best_fitness,
        time_seconds=time.perf_counter() - start,
        metadata={'generations': generations_run, 'population_size': pop_size}
    )
def _partitions(total: int, n: int, max_val: int):
    """Generate all ways to partition 'total' into 'n' non-negative integers <= max_val.

    Yields lists in ascending lexicographic order, e.g.
    ``_partitions(2, 2, 2)`` -> ``[0, 2], [1, 1], [2, 0]``.
    """
    if n == 0:
        if total == 0:
            yield []
        return
    # Prune branches that cannot be completed: n slots hold at most n * max_val.
    if total > n * max_val:
        return
    if n == 1:
        # The last slot is forced; avoid scanning total + 1 candidate values.
        if 0 <= total <= max_val:
            yield [total]
        return
    for i in range(min(total, max_val) + 1):
        for rest in _partitions(total - i, n - 1, max_val):
            yield [i] + rest
def _all_signs(abs_vals: list):
    """Yield every sign assignment of ``abs_vals`` as a list.

    A zero contributes only ``0``; each non-zero value contributes ``+v`` and
    then ``-v``. The first position flips fastest, e.g. ``[1, 2]`` yields
    ``[1, 2], [-1, 2], [1, -2], [-1, -2]``; an empty input yields ``[]`` once.
    """
    choices = [(0,) if v == 0 else (v, -v) for v in abs_vals]
    # product() varies its LAST argument fastest, so feed positions in reverse
    # and flip each combination back into the original order.
    for combo in product(*reversed(choices)):
        yield list(reversed(combo))
def _configs_at_magnitude(mag: int, n_params: int):
    """Yield every signed integer tuple of length ``n_params`` with sum(|v|) == mag.

    The order follows ``_partitions(mag, n_params, mag)``, with each partition
    expanded into its sign patterns by ``_all_signs``. The sequence is produced
    lazily.
    """
    yield from (
        tuple(signed)
        for partition in _partitions(mag, n_params, mag)
        for signed in _all_signs(partition)
    )
def prune_exhaustive_mag(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Exhaustive search over integer weight vectors in increasing total magnitude.

    For each magnitude m, it evaluates every integer vector (same layout as the
    circuit's weights) with sum(|w|) == m. The magnitudes tried are
    0..original magnitude, or only ``cfg.exhaustive_target_mag`` when that is
    >= 0. The search stops at the first magnitude where some vector reaches
    ``cfg.fitness_threshold``, so that magnitude is minimal among integer
    weight settings. Circuits with more than ``cfg.exhaustive_max_params``
    parameters are skipped and returned unchanged.
    """
    start = time.perf_counter()
    original = circuit.stats()
    n_params = original['total']
    max_mag = int(original['magnitude'])
    target_mag = cfg.exhaustive_target_mag
    if n_params > cfg.exhaustive_max_params:
        if cfg.verbose:
            print(f" [EXHAUSTIVE] Skipping: {n_params} params > max {cfg.exhaustive_max_params}")
        return PruneResult(
            method='exhaustive_mag',
            original_stats=original,
            final_stats=original,
            final_weights=circuit.clone_weights(),
            fitness=evaluator.evaluate_single(circuit.weights),
            time_seconds=time.perf_counter() - start,
            metadata={'skipped': True}
        )
    if cfg.verbose:
        print(f" [EXHAUSTIVE] Parameters: {n_params}, Original magnitude: {max_mag}")
        if target_mag >= 0:
            print(f" [EXHAUSTIVE] Target magnitude: {target_mag}")

    def to_weights(vec):
        """Convert a flat config tuple into a weight dict using the circuit's layout."""
        return circuit.vector_to_weights(torch.tensor(vec, dtype=torch.float32, device=cfg.device))

    all_solutions = []
    optimal_mag = None
    total_tested = 0
    mag_range = [target_mag] if target_mag >= 0 else range(0, max_mag + 1)
    for mag in mag_range:
        configs = list(_configs_at_magnitude(mag, n_params))
        if not configs:
            continue
        if cfg.verbose:
            print(f" Magnitude {mag}: {len(configs):,} configurations...", end=" ", flush=True)
        valid = []
        batch_size = min(100000, len(configs))
        for batch_start in range(0, len(configs), batch_size):
            batch = configs[batch_start:batch_start + batch_size]
            population = torch.tensor(batch, dtype=torch.float32, device=cfg.device)
            try:
                fitness = evaluator.evaluate_population(population)
            except Exception:
                # Best effort: score this batch one config at a time instead.
                # (A bare except here would also swallow KeyboardInterrupt.)
                fitness = torch.tensor([evaluator.evaluate_single(to_weights(c)) for c in batch], device=cfg.device)
            for i, is_valid in enumerate((fitness >= cfg.fitness_threshold).tolist()):
                if is_valid:
                    valid.append(batch[i])
        total_tested += len(configs)
        if valid:
            if cfg.verbose:
                print(f"FOUND {len(valid)} solutions!")
            optimal_mag = mag
            all_solutions = valid
            if cfg.verbose and len(valid) <= 50:
                print(f" Solutions:")
                for i, sol in enumerate(valid[:20]):
                    nz = sum(1 for v in sol if v != 0)
                    print(f" {i+1}: mag={sum(abs(v) for v in sol)}, nz={nz}, {sol}")
            break
        else:
            if cfg.verbose:
                print("none")
    if all_solutions:
        best_weights = to_weights(all_solutions[0])
        best_fitness = evaluator.evaluate_single(best_weights)
    else:
        best_weights = circuit.clone_weights()
        best_fitness = evaluator.evaluate_single(best_weights)
        optimal_mag = max_mag
    if cfg.verbose:
        print(f" [EXHAUSTIVE] Tested: {total_tested:,}, Optimal: {optimal_mag}, Solutions: {len(all_solutions)}")
    return PruneResult(
        method='exhaustive_mag',
        original_stats=original,
        final_stats=circuit.stats(best_weights),
        final_weights=best_weights,
        fitness=best_fitness,
        time_seconds=time.perf_counter() - start,
        metadata={'optimal_magnitude': optimal_mag, 'solutions_count': len(all_solutions), 'all_solutions': all_solutions[:100]}
    )
def prune_architecture(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Architecture search - find optimal flat 2-layer architecture.

    The candidate network is ``n_inputs -> cfg.arch_hidden_neurons ->
    n_outputs``. The search enumerates every integer parameter vector with
    |w| <= ``cfg.arch_max_weight``, in increasing total magnitude from 0 to
    ``cfg.arch_max_mag``. It stops at the first magnitude where some vector's
    fitness reaches ``cfg.fitness_threshold``.

    Fitness is the fraction of test cases with *every* output correct, the
    same criterion as BatchedEvaluator. Because the flat layout differs from
    the circuit's, the circuit's own weights are returned unchanged; the found
    vectors (first 100) go into ``metadata['all_solutions']``.

    Parameter layout of each vector: for every hidden neuron its ``n_inputs``
    weights and then a bias; then for every output neuron its ``n_hidden``
    weights and then a bias.
    """
    start = time.perf_counter()
    original = circuit.stats()
    n_hidden = cfg.arch_hidden_neurons
    n_inputs = circuit.spec.inputs
    n_outputs = circuit.spec.outputs
    max_weight = cfg.arch_max_weight
    max_mag = cfg.arch_max_mag
    n_params = n_hidden * (n_inputs + 1) + n_outputs * (n_hidden + 1)
    if cfg.verbose:
        print(f" [ARCH] Hidden: {n_hidden}, Params: {n_params}, Max magnitude: {max_mag}")
    test_inputs = circuit.test_inputs
    # Expected outputs as (cases, outputs) whether stored 1-D or 2-D.
    test_expected = circuit.test_expected.reshape(test_inputs.shape[0], -1)
    inputs_t = test_inputs.T  # (n_inputs, cases)

    def eval_flat(configs: torch.Tensor) -> torch.Tensor:
        """Per-row fitness of a (batch, n_params) block of flat configurations."""
        idx = 0
        hidden_w, hidden_b = [], []
        for _ in range(n_hidden):
            hidden_w.append(configs[:, idx:idx + n_inputs])
            idx += n_inputs
            hidden_b.append(configs[:, idx:idx + 1])
            idx += 1
        output_w, output_b = [], []
        for _ in range(n_outputs):
            output_w.append(configs[:, idx:idx + n_hidden])
            idx += n_hidden
            output_b.append(configs[:, idx:idx + 1])
            idx += 1
        # (batch, n_inputs) @ (n_inputs, cases) -> (batch, cases). This avoids
        # materialising a (batch, cases, n_inputs) elementwise product.
        hidden_stack = torch.stack(
            [((w @ inputs_t) + b >= 0).float() for w, b in zip(hidden_w, hidden_b)],
            dim=2,
        )  # (batch, cases, n_hidden)
        predicted = torch.stack(
            [((hidden_stack * w.unsqueeze(1)).sum(dim=2) + b >= 0).float()
             for w, b in zip(output_w, output_b)],
            dim=2,
        )  # (batch, cases, n_outputs)
        # A case counts only when every output matches.
        return (predicted == test_expected.unsqueeze(0)).all(dim=2).float().mean(dim=1)

    @lru_cache(maxsize=None)
    def partitions(total: int, n_slots: int, max_val: int) -> list:
        """All tuples of n_slots values in [0, max_val] summing to total."""
        if n_slots == 0:
            return [()] if total == 0 else []
        if n_slots == 1:
            return [(total,)] if total <= max_val else []
        result = []
        for v in range(min(total, max_val) + 1):
            for rest in partitions(total - v, n_slots - 1, max_val):
                result.append((v,) + rest)
        return result

    def signs_for_partition(partition: tuple) -> torch.Tensor:
        """All sign patterns of ``partition`` as rows of a (2**nonzeros, n) tensor."""
        n = len(partition)
        nonzero_idx = [i for i, v in enumerate(partition) if v != 0]
        k = len(nonzero_idx)
        if k == 0:
            return torch.zeros(1, n, device=cfg.device, dtype=torch.float32)
        n_patterns = 2 ** k
        configs = torch.zeros(n_patterns, n, device=cfg.device, dtype=torch.float32)
        for i, idx in enumerate(nonzero_idx):
            signs = ((torch.arange(n_patterns, device=cfg.device) >> i) & 1) * 2 - 1
            configs[:, idx] = signs.float() * partition[idx]
        return configs

    all_solutions = []
    optimal_mag = None
    total_tested = 0
    # Start at 0 (all-zero network) like the exhaustive search does.
    for target_mag in range(0, max_mag + 1):
        all_configs = [signs_for_partition(p) for p in partitions(target_mag, n_params, max_weight)]
        if not all_configs:
            continue
        configs = torch.cat(all_configs, dim=0)
        if cfg.verbose:
            print(f" Magnitude {target_mag}: {configs.shape[0]:,} configs...", end=" ", flush=True)
        valid = []
        for i in range(0, configs.shape[0], 500000):
            batch = configs[i:i + 500000]
            fitness = eval_flat(batch)
            valid.extend(batch[fitness >= cfg.fitness_threshold].cpu().tolist())
        total_tested += configs.shape[0]
        if valid:
            if cfg.verbose:
                print(f"FOUND {len(valid)} solutions!")
            optimal_mag = target_mag
            all_solutions = valid
            break
        else:
            if cfg.verbose:
                print("none")
    if cfg.verbose:
        print(f" [ARCH] Tested: {total_tested:,}, Optimal: {optimal_mag}, Solutions: {len(all_solutions)}")
    return PruneResult(
        method='architecture',
        original_stats=original,
        final_stats=original,
        final_weights=circuit.clone_weights(),
        fitness=evaluator.evaluate_single(circuit.weights),
        time_seconds=time.perf_counter() - start,
        metadata={'optimal_magnitude': optimal_mag, 'solutions_count': len(all_solutions),
                  'all_solutions': all_solutions[:100]}
    )
# =============================================================================
# COMPOSITIONAL SEARCH - For circuits built from known-optimal components
# =============================================================================
# Known optimal component families.
# Each family records its input/output/neuron counts, the minimal total weight
# magnitude, and known weight settings at that magnitude. A solution maps a
# neuron name to (input weights, bias).
OPTIMAL_COMPONENTS = {
    'xor': {
        'inputs': 2,
        'outputs': 1,
        'neurons': 3,
        'magnitude': 7,
        'solutions': [
            # Each solution: {'h1': (w, b), 'h2': (w, b), 'out': (out_w, out_b)}
            # h1, h2 read raw inputs; out reads h1, h2
            # NOTE(review): assuming that wiring and the `>= 0` step used by
            # ComputationGraph.forward_single, only the first two entries compute
            # XOR at magnitude 7. Entries 3-4 have identical hidden neurons, so
            # `out` is constant 1 (total magnitude 6). Entries 5-6 also evaluate
            # to constant 1 (total magnitude 9). Verify before relying on them.
            {'h1': ([-1, 1], 0), 'h2': ([1, -1], 0), 'out': ([-1, -1], 1)},
            {'h1': ([1, -1], 0), 'h2': ([-1, 1], 0), 'out': ([-1, -1], 1)},
            {'h1': ([-1, 1], 0), 'h2': ([-1, 1], 0), 'out': ([1, -1], 0)},
            {'h1': ([1, -1], 0), 'h2': ([1, -1], 0), 'out': ([-1, 1], 0)},
            {'h1': ([1, 1], -1), 'h2': ([-1, -1], 1), 'out': ([1, 1], -1)},
            {'h1': ([-1, -1], 1), 'h2': ([1, 1], -1), 'out': ([1, 1], -1)},
        ]
    },
    'xor3': {
        'inputs': 3,
        'outputs': 1,
        'neurons': 4,
        'magnitude': 10,
        'solutions': [
            # Known solutions at magnitude 10. Only 6 entries are listed here,
            # although prune_compositional's docstring counts 18.
            # NOTE(review): which signals h1-h3 and out read is not documented
            # here — TODO confirm. If h1-h3 all read the raw inputs and out reads
            # h1-h3, entry 1 outputs 0 for input (0, 0, 1), which is not parity.
            {'h1': ([0, 0, -1], 0), 'h2': ([-1, 1, -1], 0), 'h3': ([-1, -1, 1], 0), 'out': ([1, -1, -1], 0)},
            {'h1': ([0, 0, 1], -1), 'h2': ([1, -1, 1], -1), 'h3': ([1, 1, -1], -1), 'out': ([-1, 1, 1], 0)},
            {'h1': ([0, -1, 0], 0), 'h2': ([-1, -1, 1], 0), 'h3': ([1, -1, -1], 0), 'out': ([-1, -1, 1], 0)},
            {'h1': ([0, 1, 0], -1), 'h2': ([1, 1, -1], -1), 'h3': ([-1, 1, 1], -1), 'out': ([1, 1, -1], 0)},
            {'h1': ([-1, 0, 0], 0), 'h2': ([-1, 1, -1], 0), 'h3': ([1, -1, -1], 0), 'out': ([-1, 1, -1], 0)},
            {'h1': ([1, 0, 0], -1), 'h2': ([1, -1, 1], -1), 'h3': ([-1, 1, 1], -1), 'out': ([1, -1, 1], 0)},
        ]
    },
    'passthrough': {
        'inputs': 1,
        'outputs': 1,
        'neurons': 1,
        'magnitude': 2,
        'solutions': [
            # NOTE(review): with the `>= 0` step, ([1], 0) gives x + 0 >= 0,
            # which is 1 for both x = 0 and x = 1 (constant, magnitude 1).
            # ([2], -1) passes x through but has magnitude 3. The magnitude-2
            # passthrough would be ([1], -1) — confirm the intended convention.
            {'out': ([1], 0)},
            {'out': ([2], -1)},
        ]
    }
}
def _group_neurons_by_component(neuron_names: List[str]) -> Dict[str, List[str]]:
    """Assign every neuron to exactly one component prefix.

    A neuron's component is the text before its last '.' (the whole name when
    it has no '.'). A neuron whose full name is itself another neuron's prefix
    (e.g. ``xor`` next to ``xor.h1``) joins that component. Each neuron lands
    in exactly one group, so nested names such as ``a.b.c`` and ``a.d`` are
    not counted under both ``a`` and ``a.b``.
    """
    prefixes = {name.rsplit('.', 1)[0] for name in neuron_names}
    groups: Dict[str, List[str]] = defaultdict(list)
    for name in neuron_names:
        owner = name if name in prefixes else name.rsplit('.', 1)[0]
        groups[owner].append(name)
    return groups


def _classify_component(related: List[str]) -> str:
    """Guess a component family from its neuron count and h1/h2/h3 naming."""
    n_neurons = len(related)
    if n_neurons == 1:
        # Single neuron - likely pass-through
        return 'passthrough'
    if n_neurons == 3 and any('h1' in n or 'h2' in n for n in related):
        # 3 neurons with h1, h2 pattern - likely XOR
        return 'xor'
    if n_neurons == 4 and any('h1' in n or 'h2' in n or 'h3' in n for n in related):
        # 4 neurons with h1, h2, h3 pattern - likely XOR3
        return 'xor3'
    return 'unknown'


def prune_compositional(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """
    Compositional analysis for circuits built from known-optimal components.

    The circuit is split into components by neuron-name prefix, and each
    component is matched against ``OPTIMAL_COMPONENTS``. The search space is
    the product of the per-family solution counts. The theoretical optimum is
    the sum of the per-family magnitudes.

    Candidate solutions are not yet mapped onto the weight tensors, so this
    method never modifies the circuit. The returned result carries the
    original weights and stats, with the analysis in ``metadata``.

    Returns:
        PruneResult. If any component cannot be classified, metadata has
        'status' == 'unknown_components' and 'unknown' (the prefixes).
        Otherwise metadata has 'status' == 'theoretical_only',
        'components', 'total_combinations' and 'theoretical_magnitude'.
    """
    start = time.perf_counter()
    original = circuit.stats()

    def _unchanged_result(metadata: Dict) -> PruneResult:
        # Every exit returns the untouched circuit; only the metadata differs.
        return PruneResult(
            method='compositional',
            original_stats=original,
            final_stats=original,
            final_weights=circuit.clone_weights(),
            fitness=evaluator.evaluate_single(circuit.weights),
            time_seconds=time.perf_counter() - start,
            metadata=metadata
        )

    if cfg.verbose:
        print(" [COMP] Analyzing circuit structure...")

    # Detect component structure from neuron names
    groups = _group_neurons_by_component(list(circuit.graph.neurons.keys()))
    components = []
    for prefix in sorted(groups):
        related = groups[prefix]
        ctype = _classify_component(related)
        component = {'type': ctype, 'prefix': prefix, 'neurons': related}
        if ctype == 'unknown':
            component['n'] = len(related)
        components.append(component)

    type_counts = defaultdict(int)
    for c in components:
        type_counts[c['type']] += 1
    if cfg.verbose:
        print(" [COMP] Detected components:")
        for ctype, count in sorted(type_counts.items()):
            if ctype in OPTIMAL_COMPONENTS:
                n_solutions = len(OPTIMAL_COMPONENTS[ctype]['solutions'])
                print(f" - {ctype}: {count} instances × {n_solutions} solutions each")
            else:
                print(f" - {ctype}: {count} instances (no known optimal)")

    # Every component must belong to a known family
    unknown_components = [c for c in components if c['type'] == 'unknown']
    if unknown_components:
        if cfg.verbose:
            print(" [COMP] Cannot use compositional search - unknown components found")
            for c in unknown_components[:5]:
                print(f" - {c['prefix']}: {c['n']} neurons")
        return _unchanged_result(
            {'status': 'unknown_components', 'unknown': [c['prefix'] for c in unknown_components]}
        )

    # From here on every component type is a key of OPTIMAL_COMPONENTS.
    total_combos = math.prod(len(OPTIMAL_COMPONENTS[c['type']]['solutions']) for c in components)
    theoretical_mag = sum(OPTIMAL_COMPONENTS[c['type']]['magnitude'] for c in components)

    if cfg.verbose:
        print(f" [COMP] Total combinations: {total_combos:,}")
        # Solutions are not written into the weights yet. Enumerating or
        # sampling them would only re-evaluate the unchanged circuit, so
        # report the bound instead of burning time on it.
        print(" [COMP] Candidate solutions are not applied yet - reporting theoretical bound only")
        print(f" [COMP] Theoretical optimal magnitude: {theoretical_mag}")
        print(f" [COMP] Original magnitude: {original['magnitude']:.0f}")
        if theoretical_mag < original['magnitude']:
            print(f" [COMP] Potential reduction: {(1 - theoretical_mag / original['magnitude']) * 100:.1f}%")

    return _unchanged_result({
        'status': 'theoretical_only',
        'components': [(c['type'], c['prefix']) for c in components],
        'total_combinations': total_combos,
        'theoretical_magnitude': theoretical_mag
    })
# =============================================================================
# MAIN
# =============================================================================
def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneResult]:
    """Run every pruning method enabled in *cfg* on *circuit* and print a summary.

    The circuit's baseline fitness is checked first. If it is below
    ``cfg.fitness_threshold``, nothing runs and an empty dict is returned.
    Each enabled method runs in turn, after ``clear_vram()``. A method that
    raises is reported with its traceback and left out of the results.

    Returns:
        Mapping of method name -> PruneResult for the methods that completed.
    """
    rule = '=' * 70
    print(f"\n{rule}")
    print(f" PRUNING: {circuit.spec.name}")
    print(rule)
    usage = cfg.vram.current_usage()
    print(f" VRAM: {cfg.vram.total_gb:.1f} GB total, {usage['free_gb']:.1f} GB free")
    print(f" Device: {cfg.vram.device_name}")
    original = circuit.stats()
    orig_mag = original['magnitude']
    print(f" Inputs: {circuit.spec.inputs}, Outputs: {circuit.spec.outputs}")
    print(f" Neurons: {circuit.spec.neurons}, Layers: {circuit.spec.layers}")
    print(f" Parameters: {original['total']}, Non-zero: {original['nonzero']}")
    print(f" Magnitude: {orig_mag:.0f}")
    print(f" Test cases: {circuit.n_cases}")
    print(rule)

    evaluator = BatchedEvaluator(circuit, cfg)
    initial_fitness = evaluator.evaluate_single(circuit.weights)
    print(f"\n Initial fitness: {initial_fitness:.6f}")
    if initial_fitness < cfg.fitness_threshold:
        print(" ERROR: Circuit doesn't pass baseline!")
        return {}

    methods = [
        ('magnitude', cfg.run_magnitude, lambda: prune_magnitude(circuit, evaluator, cfg)),
        ('zero', cfg.run_zero, lambda: prune_zero(circuit, evaluator, cfg)),
        ('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
        ('exhaustive_mag', cfg.run_exhaustive_mag, lambda: prune_exhaustive_mag(circuit, evaluator, cfg)),
        ('architecture', cfg.run_architecture, lambda: prune_architecture(circuit, evaluator, cfg)),
        ('compositional', cfg.run_compositional, lambda: prune_compositional(circuit, evaluator, cfg)),
    ]
    enabled = [(name, fn) for name, is_on, fn in methods if is_on]
    print(f"\n Running {len(enabled)} pruning methods...")
    print(rule)

    results: Dict[str, PruneResult] = {}
    for i, (name, fn) in enumerate(enabled, start=1):
        print(f"\n[{i}/{len(enabled)}] {name.upper()}")
        print("-" * 50)
        try:
            clear_vram()
            r = fn()
            results[name] = r
            print(f" Fitness: {r.fitness:.6f}, Magnitude: {r.final_stats.get('magnitude', 0):.0f}, Time: {r.time_seconds:.1f}s")
        except Exception as e:
            # Top-level boundary per method: report and move on to the next one.
            print(f" ERROR: {e}")
            import traceback
            traceback.print_exc()

    # Summary
    print(f"\n{rule}")
    print(" SUMMARY")
    print(rule)
    print(f"\n{'Method':<15} {'Fitness':<10} {'Magnitude':<12} {'Reduction':<12} {'Time':<10}")
    print("-" * 60)
    # Report the measured baseline rather than assuming a perfect 1.0: the
    # baseline only has to reach cfg.fitness_threshold to get this far.
    print(f"{'Original':<15} {initial_fitness:<10.4f} {orig_mag:<12.0f} {'-':<12} {'-':<10}")
    best_method, best_mag = None, float('inf')
    for name, r in sorted(results.items(), key=lambda x: x[1].final_stats.get('magnitude', float('inf'))):
        mag = r.final_stats.get('magnitude', 0)
        reduction = f"{(1 - mag / orig_mag) * 100:.1f}%" if mag < orig_mag else "-"
        print(f"{name:<15} {r.fitness:<10.4f} {mag:<12.0f} {reduction:<12} {r.time_seconds:<10.1f}s")
        if r.fitness >= cfg.fitness_threshold and mag < best_mag:
            best_mag, best_method = mag, name
    if best_method:
        if orig_mag > 0:
            print(f"\n BEST: {best_method} ({(1 - best_mag / orig_mag) * 100:.1f}% reduction)")
        else:
            # Guard against division by zero for an all-zero original circuit.
            print(f"\n BEST: {best_method}")
    return results
def discover_circuits(base: Path = CIRCUITS_PATH) -> List[CircuitSpec]:
    """Find all circuits stored directly under *base*.

    A sub-directory counts as a circuit when it contains a ``config.json``
    and at least one ``*.safetensors`` file. Missing config keys fall back to
    the directory name (for ``name``) or 0 (for the numeric fields).

    Directories whose config cannot be read or parsed are skipped
    (best-effort discovery). Anything else, including KeyboardInterrupt,
    propagates.

    Returns:
        CircuitSpecs sorted by (inputs, neurons).
    """
    circuits = []
    for d in base.iterdir():
        config_path = d / 'config.json'
        # any() stops at the first weights file instead of listing them all.
        if not (d.is_dir() and config_path.exists() and any(d.glob('*.safetensors'))):
            continue
        try:
            # JSON is UTF-8 by spec; do not depend on the platform locale.
            with open(config_path, encoding='utf-8') as f:
                meta = json.load(f)
            circuits.append(CircuitSpec(
                name=meta.get('name', d.name),
                path=d,
                inputs=meta.get('inputs', 0),
                outputs=meta.get('outputs', 0),
                neurons=meta.get('neurons', 0),
                layers=meta.get('layers', 0),
                parameters=meta.get('parameters', 0)
            ))
        except (OSError, ValueError, AttributeError, TypeError):
            # Unreadable file, malformed JSON (JSONDecodeError/UnicodeDecodeError
            # are ValueErrors), a non-object config, or bad spec fields: skip it.
            continue
    return sorted(circuits, key=lambda x: (x.inputs, x.neurons))
# Aliases accepted by --methods, mapped to the suffix of the Config.run_<name> flag.
_METHOD_ALIASES = {
    'mag': 'magnitude', 'magnitude': 'magnitude',
    'zero': 'zero',
    'evo': 'evolutionary', 'evolutionary': 'evolutionary',
    'exh': 'exhaustive_mag', 'exh_mag': 'exhaustive_mag', 'exhaustive': 'exhaustive_mag',
    'arch': 'architecture', 'architecture': 'architecture',
    'comp': 'compositional', 'compositional': 'compositional'
}


def _enable_methods(cfg, methods_arg: str) -> None:
    """Set ``cfg.run_<method> = True`` for each comma-separated alias in *methods_arg*.

    Matching is case-insensitive, and blank entries (e.g. a trailing comma) are
    ignored. Unknown aliases are reported instead of being dropped silently.
    """
    for token in methods_arg.lower().split(','):
        alias = token.strip()
        if not alias:
            continue
        method = _METHOD_ALIASES.get(alias)
        if method is None:
            print(f"Warning: unknown method '{alias}' ignored "
                  f"(choose from: {', '.join(sorted(_METHOD_ALIASES))})")
            continue
        setattr(cfg, f'run_{method}', True)


def _save_best(circuit, results, cfg) -> None:
    """Save the lowest-magnitude result that meets the fitness threshold.

    Results are filtered by fitness before picking the minimum magnitude, so a
    failing low-magnitude result cannot block saving a passing one.
    """
    passing = [r for r in results.values() if r.fitness >= cfg.fitness_threshold]
    if not passing:
        print("\nNothing saved: no method met the fitness threshold")
        return
    best = min(passing, key=lambda r: r.final_stats.get('magnitude', float('inf')))
    path = circuit.save_weights(best.final_weights, f'pruned_{best.method}')
    print(f"\nSaved to: {path}")


def main():
    """Command-line entry point: list circuits, or prune one circuit or all of them."""
    parser = argparse.ArgumentParser(description='Threshold Circuit Pruner v5')
    parser.add_argument('circuit', nargs='?', help='Circuit name')
    parser.add_argument('--weights', type=str, help='Specific .safetensors file')
    parser.add_argument('--list', action='store_true')
    parser.add_argument('--all', action='store_true')
    parser.add_argument('--max-inputs', type=int, default=10)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--methods', type=str, help='Comma-separated: mag,zero,evo,exh,arch,comp')
    parser.add_argument('--fitness', type=float, default=0.9999)
    parser.add_argument('--quiet', action='store_true')
    parser.add_argument('--save', action='store_true')
    parser.add_argument('--evo-pop', type=int, default=0)
    parser.add_argument('--evo-gens', type=int, default=2000)
    parser.add_argument('--exhaustive-max-params', type=int, default=12)
    parser.add_argument('--target-mag', type=int, default=-1)
    parser.add_argument('--arch-hidden', type=int, default=3)
    parser.add_argument('--arch-max-weight', type=int, default=3)
    parser.add_argument('--arch-max-mag', type=int, default=20)
    args = parser.parse_args()

    if args.list:
        specs = discover_circuits()
        print(f"\nAvailable circuits ({len(specs)}):\n")
        for s in specs:
            print(f" {s.name:<40} {s.inputs}in/{s.outputs}out {s.neurons}N {s.parameters}P")
        return

    vram_cfg = VRAMConfig()
    cfg = Config(
        device=args.device,
        fitness_threshold=args.fitness,
        verbose=not args.quiet,
        vram=vram_cfg,
        evo_pop_size=args.evo_pop,
        evo_generations=args.evo_gens,
        exhaustive_max_params=args.exhaustive_max_params,
        exhaustive_target_mag=args.target_mag,
        arch_hidden_neurons=args.arch_hidden,
        arch_max_weight=args.arch_max_weight,
        arch_max_mag=args.arch_max_mag
    )
    if args.methods:
        _enable_methods(cfg, args.methods)
    RESULTS_PATH.mkdir(exist_ok=True)

    if args.all:
        specs = [s for s in discover_circuits() if s.inputs <= args.max_inputs]
        print(f"\nRunning on {len(specs)} circuits...")
        for spec in specs:
            try:
                circuit = AdaptiveCircuit(spec.path, cfg.device)
                run_all_methods(circuit, cfg)
                clear_vram()
            except Exception as e:
                # Batch boundary: one broken circuit must not stop the others.
                print(f"ERROR on {spec.name}: {e}")
    elif args.circuit:
        # Accept both the full directory name and the short form without 'threshold-'.
        path = CIRCUITS_PATH / args.circuit
        if not path.exists():
            path = CIRCUITS_PATH / f'threshold-{args.circuit}'
        if not path.exists():
            print(f"Circuit not found: {args.circuit}")
            return
        circuit = AdaptiveCircuit(path, cfg.device, args.weights)
        results = run_all_methods(circuit, cfg)
        if args.save and results:
            _save_best(circuit, results, cfg)
    else:
        parser.print_help()
        print("\n\nExamples:")
        print(" python prune.py --list")
        print(" python prune.py threshold-xor --methods evo")
        print(" python prune.py threshold-xor --methods exh --exhaustive-max-params 20")
        print(" python prune.py threshold-crc16-mag53 --methods comp")
        print(" python prune.py --all --max-inputs 8 --methods mag,zero")


if __name__ == '__main__':
    main()