""" Threshold Circuit Pruner v5 (Refactored) Streamlined pruning framework with 6 core methods: - magnitude: Greedy weight reduction - zero: Try zeroing individual weights - evolutionary: GPU-parallel genetic algorithm - exhaustive_mag: Provably optimal for small circuits - architecture: Search flat 2-layer alternatives - compositional: For circuits built from known-optimal components Usage: python prune.py threshold-xor --methods evo python prune.py threshold-xor --methods exh_mag python prune.py threshold-crc16-mag53 --methods comp python prune.py --list """ import torch import torch.nn.functional as F import json import time import random import argparse import math import gc import re import importlib.util from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from dataclasses import dataclass, field from typing import Dict, List, Tuple, Optional, Callable from safetensors.torch import load_file, save_file from collections import defaultdict from functools import lru_cache from itertools import combinations, product import warnings warnings.filterwarnings('ignore') CIRCUITS_PATH = Path('D:/threshold-logic-circuits') RESULTS_PATH = CIRCUITS_PATH / 'pruned_results' @dataclass class VRAMConfig: target_residency: float = 0.75 safety_margin: float = 0.05 def __post_init__(self): self.total_bytes = 0 self.device_name = "CPU" if torch.cuda.is_available(): props = torch.cuda.get_device_properties(0) self.total_bytes = props.total_memory self.device_name = props.name @property def total_gb(self) -> float: return self.total_bytes / 1e9 @property def available_bytes(self) -> int: return int(self.total_bytes * (self.target_residency - self.safety_margin)) def current_usage(self) -> Dict: if not torch.cuda.is_available(): return {'allocated_gb': 0, 'free_gb': 0, 'utilization': 0} allocated = torch.cuda.memory_allocated() return { 'allocated_gb': allocated / 1e9, 'free_gb': (self.total_bytes - allocated) / 1e9, 'utilization': allocated / self.total_bytes if self.total_bytes > 0 else 0 } def clear_vram(): gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() @dataclass class Config: device: str = 'cuda' fitness_threshold: float = 0.9999 verbose: bool = True vram: VRAMConfig = field(default_factory=VRAMConfig) run_magnitude: bool = False run_zero: bool = False run_evolutionary: bool = False run_exhaustive_mag: bool = False run_architecture: bool = False run_compositional: bool = False magnitude_passes: int = 100 exhaustive_max_params: int = 12 exhaustive_target_mag: int = -1 evo_generations: int = 2000 evo_pop_size: int = 0 evo_elite_ratio: float = 0.05 evo_mutation_rate: float = 0.15 evo_mutation_strength: float = 2.0 evo_crossover_rate: float = 0.3 evo_parsimony: float = 0.001 arch_hidden_neurons: int = 3 arch_max_weight: int = 3 arch_max_mag: int = 20 @dataclass class CircuitSpec: name: str path: Path inputs: int outputs: int neurons: int layers: int parameters: int description: str = "" @dataclass class PruneResult: method: str original_stats: Dict final_stats: Dict final_weights: Dict[str, torch.Tensor] fitness: float time_seconds: float metadata: Dict = field(default_factory=dict) class ComputationGraph: """Parses weight structure to build dependency graph.""" def __init__(self, weights: Dict[str, torch.Tensor], n_inputs: int, n_outputs: int, device: str): self.device = device self.n_inputs = n_inputs self.n_outputs = n_outputs self.weights = weights self.neurons = {} self.neuron_order = [] self.output_neurons = [] self.layer_groups = defaultdict(list) self._parse_structure() self._build_execution_order() def _parse_structure(self): neuron_weights = defaultdict(dict) for key, tensor in self.weights.items(): if '.weight' in key: neuron_name = key.replace('.weight', '') neuron_weights[neuron_name]['weight'] = key neuron_weights[neuron_name]['weight_shape'] = tensor.shape elif '.bias' in key: neuron_name = key.replace('.bias', '') neuron_weights[neuron_name]['bias'] = key for neuron_name, params in neuron_weights.items(): if 'weight' in params: depth = self._estimate_depth(neuron_name) self.neurons[neuron_name] = { 'weight_key': params.get('weight'), 'bias_key': params.get('bias'), 'weight_shape': params.get('weight_shape', (1,)), 'input_size': params.get('weight_shape', (1,))[-1], 'depth': depth } self.layer_groups[depth].append(neuron_name) self._infer_depth_from_shapes() self._identify_outputs() def _estimate_depth(self, name: str) -> int: depth = 0 if 'layer' in name: match = re.search(r'layer(\d+)', name) if match: depth = int(match.group(1)) depth += len(name.split('.')) - 1 return depth def _infer_depth_from_shapes(self): for name, info in self.neurons.items(): if info.get('input_size', 0) == self.n_inputs: info['depth'] = 0 info['input_source'] = 'raw' self.layer_groups = defaultdict(list) for name, info in self.neurons.items(): self.layer_groups[info['depth']].append(name) def _identify_outputs(self): candidates = [] for name in self.neurons: is_parent = any(name + '.' in other for other in self.neurons if other != name) if not is_parent: candidates.append((name, self.neurons[name]['depth'])) candidates.sort(key=lambda x: (-x[1], x[0])) self.output_neurons = [c[0] for c in candidates[:self.n_outputs]] def _build_execution_order(self): self.neuron_order = sorted(self.neurons.keys(), key=lambda n: (self.neurons[n]['depth'], n)) def forward_single(self, inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor: activations = {'input': inputs} for neuron_name in self.neuron_order: info = self.neurons[neuron_name] w_key, b_key = info['weight_key'], info['bias_key'] if w_key and w_key in weights: w = weights[w_key] if w.dim() == 1: w = w.unsqueeze(0) inp = self._get_neuron_input(neuron_name, activations, inputs, w.shape[-1]) if inp.dim() == 1: out = inp @ w.flatten() else: out = (inp.unsqueeze(-2) @ w.T.unsqueeze(0)).squeeze(-2) if out.dim() > 1 and out.shape[-1] == 1: out = out.squeeze(-1) if b_key and b_key in weights: out = out + weights[b_key].squeeze() activations[neuron_name] = (out >= 0).float() return self._collect_outputs(activations) def _get_neuron_input(self, neuron_name: str, activations: Dict, raw_input: torch.Tensor, expected_size: int) -> torch.Tensor: info = self.neurons.get(neuron_name, {}) if info.get('input_source') == 'raw' or expected_size == self.n_inputs: return raw_input if 'layer2' in neuron_name or '.out' in neuron_name: base = neuron_name.replace('.layer2', '').replace('.out', '') hidden_keys = [k for k in activations if k.startswith(base) and k != neuron_name and k != 'input'] if len(hidden_keys) == expected_size: return torch.stack([activations[k] for k in sorted(hidden_keys)], dim=-1) return raw_input[..., :expected_size] if raw_input.shape[-1] >= expected_size else raw_input def _collect_outputs(self, activations: Dict) -> torch.Tensor: outputs = [activations[n] for n in sorted(self.output_neurons) if n in activations] if outputs: return torch.stack(outputs, dim=-1) if outputs[0].dim() > 0 else torch.stack(outputs) return torch.zeros(self.n_outputs, device=self.device) class AdaptiveCircuit: """Adaptive threshold circuit with automatic evaluation.""" def __init__(self, path: Path, device: str = 'cuda', weights_file: str = None): self.path = Path(path) self.device = device self.spec = self._load_spec() self.weights = self._load_weights(weights_file) self.weight_keys = list(self.weights.keys()) self.n_weights = sum(t.numel() for t in self.weights.values()) self.native_forward = self._try_load_native_forward() self.has_native = self.native_forward is not None print(f" [LOAD] Native forward: {'FOUND' if self.has_native else 'NOT FOUND'}") print(f" [LOAD] Parsing circuit topology...") self.graph = ComputationGraph(self.weights, self.spec.inputs, self.spec.outputs, device) print(f" [LOAD] Found {len(self.graph.neurons)} neurons across {len(self.graph.layer_groups)} layers") print(f" [LOAD] Output neurons: {sorted(self.graph.output_neurons)}") print(f" [LOAD] Building test cases...") self.test_inputs, self.test_expected = self._build_tests() self.n_cases = self.test_inputs.shape[0] print(f" [LOAD] Generated {self.n_cases} test cases") self._compile_fast_forward() print(f" [LOAD] Circuit ready: {self.n_weights} weight parameters") def _try_load_native_forward(self) -> Optional[Callable]: model_py = self.path / 'model.py' if not model_py.exists(): return None try: spec = importlib.util.spec_from_file_location("circuit_model", model_py) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) if hasattr(module, 'forward'): return module.forward return None except Exception: return None def evaluate_native(self, inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor: if not self.has_native: return self.graph.forward_single(inputs, weights) try: out = self.native_forward(inputs, weights) if isinstance(out, torch.Tensor): return out.to(self.device) if out.dim() > 1 else out.unsqueeze(-1).to(self.device) except Exception: pass cpu_weights = {k: v.cpu() for k, v in weights.items()} results = [] for i in range(inputs.shape[0]): inp = [int(x) for x in inputs[i].cpu().tolist()] try: out = self.native_forward(inp, cpu_weights) results.append([float(x) for x in out] if isinstance(out, (list, tuple)) else [float(out)]) except Exception: results.append([0.0] * self.spec.outputs) return torch.tensor(results, device=self.device, dtype=torch.float32) def _load_spec(self) -> CircuitSpec: with open(self.path / 'config.json') as f: cfg = json.load(f) return CircuitSpec( name=cfg.get('name', self.path.name), path=self.path, inputs=cfg.get('inputs', cfg.get('input_size', 0)), outputs=cfg.get('outputs', cfg.get('output_size', 0)), neurons=cfg.get('neurons', 0), layers=cfg.get('layers', 0), parameters=cfg.get('parameters', 0), description=cfg.get('description', '') ) def _load_weights(self, weights_file: str = None) -> Dict[str, torch.Tensor]: if weights_file: sf = self.path / weights_file else: sf = self.path / 'model.safetensors' if not sf.exists(): candidates = list(self.path.glob('*.safetensors')) sf = candidates[0] if candidates else sf w = load_file(str(sf)) return {k: v.float().to(self.device) for k, v in w.items()} def _build_tests(self) -> Tuple[torch.Tensor, torch.Tensor]: if self.has_native: return self._build_native_tests() return self._build_exhaustive_tests() def _build_exhaustive_tests(self) -> Tuple[torch.Tensor, torch.Tensor]: n = self.spec.inputs if n > 24: raise ValueError(f"Input space too large: 2^{n}") n_cases = 2 ** n idx = torch.arange(n_cases, device=self.device, dtype=torch.long) bits = torch.arange(n, device=self.device, dtype=torch.long) inputs = ((idx.unsqueeze(1) >> bits) & 1).float() expected = self.graph.forward_single(inputs, self.weights) return inputs, expected def _build_native_tests(self) -> Tuple[torch.Tensor, torch.Tensor]: n = self.spec.inputs if n > 20: raise ValueError(f"Input space too large: 2^{n}") n_cases = 2 ** n inputs_list, expected_list = [], [] cpu_weights = {k: v.cpu() for k, v in self.weights.items()} for i in range(n_cases): inp = [(i >> b) & 1 for b in range(n)] inputs_list.append(inp) out = self.native_forward(inp, cpu_weights) expected_list.append([float(x) for x in out] if isinstance(out, (list, tuple)) else [float(out)]) return (torch.tensor(inputs_list, device=self.device, dtype=torch.float32), torch.tensor(expected_list, device=self.device, dtype=torch.float32)) def _compile_fast_forward(self): self.weight_layout = [] offset = 0 for key in self.weight_keys: size = self.weights[key].numel() self.weight_layout.append((key, offset, offset + size, self.weights[key].shape)) offset += size self.base_vector = self.weights_to_vector(self.weights) def weights_to_vector(self, weights: Dict[str, torch.Tensor]) -> torch.Tensor: return torch.cat([weights[k].flatten() for k in self.weight_keys]) def vector_to_weights(self, vector: torch.Tensor) -> Dict[str, torch.Tensor]: weights = {} for key, start, end, shape in self.weight_layout: weights[key] = vector[start:end].view(shape) return weights def clone_weights(self) -> Dict[str, torch.Tensor]: return {k: v.clone() for k, v in self.weights.items()} def stats(self, weights: Dict[str, torch.Tensor] = None) -> Dict: w = weights or self.weights total = sum(t.numel() for t in w.values()) nonzero = sum((t != 0).sum().item() for t in w.values()) mag = sum(t.abs().sum().item() for t in w.values()) maxw = max(t.abs().max().item() for t in w.values()) if w else 0 return { 'total': total, 'nonzero': nonzero, 'sparsity': 1 - nonzero / total if total else 0, 'magnitude': mag, 'max_weight': maxw } def save_weights(self, weights: Dict[str, torch.Tensor], suffix: str = 'pruned') -> Path: path = self.path / f'model_{suffix}.safetensors' save_file({k: v.cpu() for k, v in weights.items()}, str(path)) return path class BatchedEvaluator: """GPU-optimized batched population evaluation.""" def __init__(self, circuit: AdaptiveCircuit, cfg: Config): self.circuit = circuit self.cfg = cfg self.device = cfg.device self.test_inputs = circuit.test_inputs self.test_expected = circuit.test_expected self.n_cases = circuit.n_cases self.n_weights = circuit.n_weights if cfg.verbose: print(f" [EVAL] Initializing evaluator...") self._calculate_batch_size() self._validate_evaluation() if cfg.verbose: print(f" [EVAL] Evaluator ready: batch={self.max_batch:,}, native={circuit.has_native}") def _calculate_batch_size(self): bytes_per_ind = self.n_weights * 4 * 2 + self.n_cases * self.circuit.spec.outputs * 4 + 4096 available = self.cfg.vram.available_bytes self.max_batch = max(1000, min(available // max(bytes_per_ind, 1), 5_000_000)) def _validate_evaluation(self): fitness = self.evaluate_single(self.circuit.weights) if fitness < 0.999 and self.cfg.verbose: print(f" [EVAL WARNING] Original weights fitness={fitness:.4f}") def evaluate_single(self, weights: Dict[str, torch.Tensor]) -> float: with torch.no_grad(): if self.circuit.has_native: outputs = self.circuit.evaluate_native(self.test_inputs, weights) else: outputs = self.circuit.graph.forward_single(self.test_inputs, weights) if outputs.shape != self.test_expected.shape: if outputs.dim() == 1: outputs = outputs.unsqueeze(0).expand(self.test_expected.shape[0], -1) correct = (outputs == self.test_expected).all(dim=-1).float().sum() return (correct / self.n_cases).item() def evaluate_population(self, population: torch.Tensor) -> torch.Tensor: pop_size = population.shape[0] if pop_size > self.max_batch: return self._evaluate_chunked(population) return self._evaluate_sequential(population) def _evaluate_sequential(self, population: torch.Tensor) -> torch.Tensor: pop_size = population.shape[0] fitness = torch.zeros(pop_size, device=self.device) with torch.no_grad(): for i in range(pop_size): weights = self.circuit.vector_to_weights(population[i]) if self.circuit.has_native: outputs = self.circuit.evaluate_native(self.test_inputs, weights) else: outputs = self.circuit.graph.forward_single(self.test_inputs, weights) if outputs.shape != self.test_expected.shape: if outputs.dim() == 1: outputs = outputs.unsqueeze(0).expand(self.test_expected.shape[0], -1) if outputs.shape == self.test_expected.shape: correct = (outputs == self.test_expected).all(dim=-1).float().sum() fitness[i] = correct / self.n_cases return fitness def _evaluate_chunked(self, population: torch.Tensor) -> torch.Tensor: pop_size = population.shape[0] fitness = torch.zeros(pop_size, device=self.device) for start in range(0, pop_size, self.max_batch): end = min(start + self.max_batch, pop_size) fitness[start:end] = self._evaluate_sequential(population[start:end]) return fitness # ============================================================================= # PRUNING METHODS # ============================================================================= def prune_magnitude(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult: """Iterative magnitude reduction - try reducing each weight toward zero.""" start = time.perf_counter() weights = circuit.clone_weights() original = circuit.stats(weights) if cfg.verbose: print(f" Starting magnitude reduction (passes={cfg.magnitude_passes})...") for pass_num in range(cfg.magnitude_passes): candidates = [] for name, tensor in weights.items(): flat = tensor.flatten() for i in range(len(flat)): val = flat[i].item() if val != 0: new_val = val - 1 if val > 0 else val + 1 candidates.append((name, i, tensor.shape, val, new_val)) if not candidates: break random.shuffle(candidates) reductions = 0 for name, idx, shape, old_val, new_val in candidates: flat = weights[name].flatten() flat[idx] = new_val weights[name] = flat.view(shape) if evaluator.evaluate_single(weights) >= cfg.fitness_threshold: reductions += 1 else: flat[idx] = old_val weights[name] = flat.view(shape) if cfg.verbose: stats = circuit.stats(weights) print(f" Pass {pass_num}: {reductions} reductions, mag={stats['magnitude']:.0f}") if reductions == 0: break return PruneResult( method='magnitude', original_stats=original, final_stats=circuit.stats(weights), final_weights=weights, fitness=evaluator.evaluate_single(weights), time_seconds=time.perf_counter() - start ) def prune_zero(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult: """Zero pruning - try setting each non-zero weight to zero.""" start = time.perf_counter() weights = circuit.clone_weights() original = circuit.stats(weights) candidates = [] for name, tensor in weights.items(): flat = tensor.flatten() for i in range(len(flat)): if flat[i].item() != 0: candidates.append((name, i, tensor.shape, flat[i].item())) random.shuffle(candidates) if cfg.verbose: print(f" Testing {len(candidates)} non-zero weights for zeroing...") zeroed = 0 for name, idx, shape, old_val in candidates: flat = weights[name].flatten() flat[idx] = 0 weights[name] = flat.view(shape) if evaluator.evaluate_single(weights) >= cfg.fitness_threshold: zeroed += 1 else: flat[idx] = old_val weights[name] = flat.view(shape) if cfg.verbose: print(f" Zeroed {zeroed}/{len(candidates)} weights") return PruneResult( method='zero', original_stats=original, final_stats=circuit.stats(weights), final_weights=weights, fitness=evaluator.evaluate_single(weights), time_seconds=time.perf_counter() - start ) def prune_evolutionary(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult: """Evolutionary search with GPU-optimized parallel population evaluation.""" start = time.perf_counter() original = circuit.stats() if cfg.evo_pop_size > 0: pop_size = cfg.evo_pop_size else: bytes_per_ind = circuit.n_weights * 4 * 3 available = cfg.vram.available_bytes - torch.cuda.memory_allocated() if torch.cuda.is_available() else cfg.vram.available_bytes pop_size = max(10000, min(available // max(bytes_per_ind, 1), evaluator.max_batch, 500000)) elite_size = max(1, int(pop_size * cfg.evo_elite_ratio)) if cfg.verbose: print(f" [EVO] Population: {pop_size:,}, Elite: {elite_size:,}, Generations: {cfg.evo_generations}") base_vector = circuit.weights_to_vector(circuit.weights) population = base_vector.unsqueeze(0).expand(pop_size, -1).clone() # Initialize with varied mutations n_exact = max(elite_size, pop_size // 10) for i in range(n_exact, pop_size // 2): n_muts = max(1, circuit.n_weights // 10) mut_idx = torch.randperm(circuit.n_weights)[:n_muts] population[i, mut_idx] += torch.randint(-1, 2, (n_muts,), device=cfg.device, dtype=population.dtype) for i in range(pop_size // 2, pop_size): n_muts = max(1, circuit.n_weights // 4) mut_idx = torch.randperm(circuit.n_weights)[:n_muts] population[i, mut_idx] += torch.randint(-2, 3, (n_muts,), device=cfg.device, dtype=population.dtype) best_weights = circuit.clone_weights() best_fitness = evaluator.evaluate_single(best_weights) best_mag = original['magnitude'] best_score = best_fitness - cfg.evo_parsimony * best_mag / circuit.n_weights if best_fitness >= cfg.fitness_threshold else -float('inf') stagnant = 0 mutation_rate = cfg.evo_mutation_rate for gen in range(cfg.evo_generations): fitness = evaluator.evaluate_population(population) magnitudes = population.abs().sum(dim=1) adjusted = fitness - cfg.evo_parsimony * magnitudes / circuit.n_weights valid_mask = fitness >= cfg.fitness_threshold n_valid = valid_mask.sum().item() if n_valid > 0: valid_adjusted = adjusted.clone() valid_adjusted[~valid_mask] = -float('inf') best_idx = valid_adjusted.argmax().item() if adjusted[best_idx] > best_score: best_score = adjusted[best_idx].item() best_fitness = fitness[best_idx].item() best_weights = circuit.vector_to_weights(population[best_idx].clone()) best_mag = magnitudes[best_idx].item() stagnant = 0 else: stagnant += 1 else: stagnant += 1 # Adaptive mutation if stagnant > 50: mutation_rate = min(0.5, mutation_rate * 1.1) elif stagnant == 0: mutation_rate = max(0.01, mutation_rate * 0.95) if cfg.verbose and (gen % 50 == 0 or gen == cfg.evo_generations - 1): print(f" Gen {gen:4d} | valid: {n_valid:6,}/{pop_size:,} | mag: {best_mag:.0f} | stag: {stagnant}") # Selection and reproduction sorted_idx = adjusted.argsort(descending=True) elite = population[sorted_idx[:elite_size]].clone() probs = F.softmax(adjusted * 10, dim=0) parent_idx = torch.multinomial(probs, pop_size - elite_size, replacement=True) children = population[parent_idx].clone() # Crossover if cfg.evo_crossover_rate > 0: cross_mask = torch.rand(len(children), device=cfg.device) < cfg.evo_crossover_rate cross_idx = torch.where(cross_mask)[0] for i in range(0, len(cross_idx) - 1, 2): p1, p2 = cross_idx[i].item(), cross_idx[i + 1].item() point = random.randint(1, circuit.n_weights - 1) temp = children[p1, point:].clone() children[p1, point:] = children[p2, point:] children[p2, point:] = temp # Mutation mut_mask = torch.rand_like(children) < mutation_rate mutations = torch.randint(-int(cfg.evo_mutation_strength), int(cfg.evo_mutation_strength) + 1, children.shape, device=cfg.device, dtype=children.dtype) children = children + mut_mask * mutations population = torch.cat([elite, children], dim=0) if stagnant > 200: if cfg.verbose: print(f" [EVO] Early stop at generation {gen}") break if cfg.verbose: final_stats = circuit.stats(best_weights) print(f" [EVO] Final: mag={final_stats['magnitude']:.0f} (was {original['magnitude']:.0f})") return PruneResult( method='evolutionary', original_stats=original, final_stats=circuit.stats(best_weights), final_weights=best_weights, fitness=best_fitness, time_seconds=time.perf_counter() - start, metadata={'generations': gen + 1, 'population_size': pop_size} ) def _partitions(total: int, n: int, max_val: int): """Generate all ways to partition 'total' into 'n' non-negative integers <= max_val.""" if n == 0: if total == 0: yield [] return for i in range(min(total, max_val) + 1): for rest in _partitions(total - i, n - 1, max_val): yield [i] + rest def _all_signs(abs_vals: list): """Generate all sign combinations for absolute values.""" if not abs_vals: yield [] return for rest in _all_signs(abs_vals[1:]): if abs_vals[0] == 0: yield [0] + rest else: yield [abs_vals[0]] + rest yield [-abs_vals[0]] + rest def _configs_at_magnitude(mag: int, n_params: int): """Generate all n_params-length configs with given total magnitude.""" for partition in _partitions(mag, n_params, mag): for signed in _all_signs(partition): yield tuple(signed) def prune_exhaustive_mag(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult: """Exhaustive search by magnitude - finds provably optimal solutions.""" start = time.perf_counter() original = circuit.stats() n_params = original['total'] max_mag = int(original['magnitude']) target_mag = cfg.exhaustive_target_mag if n_params > cfg.exhaustive_max_params: if cfg.verbose: print(f" [EXHAUSTIVE] Skipping: {n_params} params > max {cfg.exhaustive_max_params}") return PruneResult( method='exhaustive_mag', original_stats=original, final_stats=original, final_weights=circuit.clone_weights(), fitness=evaluator.evaluate_single(circuit.weights), time_seconds=time.perf_counter() - start, metadata={'skipped': True} ) if cfg.verbose: print(f" [EXHAUSTIVE] Parameters: {n_params}, Original magnitude: {max_mag}") if target_mag >= 0: print(f" [EXHAUSTIVE] Target magnitude: {target_mag}") weight_keys = list(circuit.weights.keys()) weight_shapes = {k: circuit.weights[k].shape for k in weight_keys} weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys} def vector_to_weights(vec): weights = {} idx = 0 for k in weight_keys: size = weight_sizes[k] weights[k] = torch.tensor(vec[idx:idx+size], dtype=torch.float32, device=cfg.device).view(weight_shapes[k]) idx += size return weights all_solutions = [] optimal_mag = None total_tested = 0 mag_range = [target_mag] if target_mag >= 0 else range(0, max_mag + 1) for mag in mag_range: configs = list(_configs_at_magnitude(mag, n_params)) if not configs: continue if cfg.verbose: print(f" Magnitude {mag}: {len(configs):,} configurations...", end=" ", flush=True) valid = [] batch_size = min(100000, len(configs)) for batch_start in range(0, len(configs), batch_size): batch = configs[batch_start:batch_start + batch_size] population = torch.tensor(batch, dtype=torch.float32, device=cfg.device) try: fitness = evaluator.evaluate_population(population) except: fitness = torch.tensor([evaluator.evaluate_single(vector_to_weights(c)) for c in batch], device=cfg.device) for i, is_valid in enumerate((fitness >= cfg.fitness_threshold).tolist()): if is_valid: valid.append(batch[i]) total_tested += len(configs) if valid: if cfg.verbose: print(f"FOUND {len(valid)} solutions!") optimal_mag = mag all_solutions = valid if cfg.verbose and len(valid) <= 50: print(f" Solutions:") for i, sol in enumerate(valid[:20]): nz = sum(1 for v in sol if v != 0) print(f" {i+1}: mag={sum(abs(v) for v in sol)}, nz={nz}, {sol}") break else: if cfg.verbose: print("none") if all_solutions: best_weights = vector_to_weights(all_solutions[0]) best_fitness = evaluator.evaluate_single(best_weights) else: best_weights = circuit.clone_weights() best_fitness = evaluator.evaluate_single(best_weights) optimal_mag = max_mag if cfg.verbose: print(f" [EXHAUSTIVE] Tested: {total_tested:,}, Optimal: {optimal_mag}, Solutions: {len(all_solutions)}") return PruneResult( method='exhaustive_mag', original_stats=original, final_stats=circuit.stats(best_weights), final_weights=best_weights, fitness=best_fitness, time_seconds=time.perf_counter() - start, metadata={'optimal_magnitude': optimal_mag, 'solutions_count': len(all_solutions), 'all_solutions': all_solutions[:100]} ) def prune_architecture(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult: """Architecture search - find optimal flat 2-layer architecture.""" start = time.perf_counter() original = circuit.stats() n_hidden = cfg.arch_hidden_neurons n_inputs = circuit.spec.inputs n_outputs = circuit.spec.outputs max_weight = cfg.arch_max_weight max_mag = cfg.arch_max_mag n_params = n_hidden * (n_inputs + 1) + n_outputs * (n_hidden + 1) if cfg.verbose: print(f" [ARCH] Hidden: {n_hidden}, Params: {n_params}, Max magnitude: {max_mag}") test_inputs = circuit.test_inputs test_expected = circuit.test_expected def eval_flat(configs: torch.Tensor) -> torch.Tensor: batch_size = configs.shape[0] idx = 0 hidden_w, hidden_b = [], [] for _ in range(n_hidden): hidden_w.append(configs[:, idx:idx+n_inputs]) idx += n_inputs hidden_b.append(configs[:, idx:idx+1]) idx += 1 output_w, output_b = [], [] for _ in range(n_outputs): output_w.append(configs[:, idx:idx+n_hidden]) idx += n_hidden output_b.append(configs[:, idx:idx+1]) idx += 1 hidden_acts = [] for h in range(n_hidden): act = (hidden_w[h].unsqueeze(1) * test_inputs.unsqueeze(0)).sum(dim=2) + hidden_b[h] hidden_acts.append((act >= 0).float()) hidden_stack = torch.stack(hidden_acts, dim=2) outputs = [] for o in range(n_outputs): out = (hidden_stack * output_w[o].unsqueeze(1)).sum(dim=2) + output_b[o] outputs.append((out >= 0).float()) if n_outputs == 1: predicted = outputs[0] expected = test_expected.squeeze() else: predicted = torch.stack(outputs, dim=2) expected = test_expected correct = (predicted == expected.unsqueeze(0)).float().mean(dim=1) if n_outputs > 1: correct = correct.mean(dim=1) return correct @lru_cache(maxsize=None) def partitions(total: int, n_slots: int, max_val: int) -> list: if n_slots == 0: return [()] if total == 0 else [] if n_slots == 1: return [(total,)] if total <= max_val else [] result = [] for v in range(min(total, max_val) + 1): for rest in partitions(total - v, n_slots - 1, max_val): result.append((v,) + rest) return result def signs_for_partition(partition: tuple) -> torch.Tensor: n = len(partition) nonzero_idx = [i for i, v in enumerate(partition) if v != 0] k = len(nonzero_idx) if k == 0: return torch.zeros(1, n, device=cfg.device, dtype=torch.float32) n_patterns = 2 ** k configs = torch.zeros(n_patterns, n, device=cfg.device, dtype=torch.float32) for i, idx in enumerate(nonzero_idx): signs = ((torch.arange(n_patterns, device=cfg.device) >> i) & 1) * 2 - 1 configs[:, idx] = signs.float() * partition[idx] return configs all_solutions = [] optimal_mag = None total_tested = 0 for target_mag in range(1, max_mag + 1): all_configs = [] for partition in partitions(target_mag, n_params, max_weight): all_configs.append(signs_for_partition(partition)) if not all_configs: continue configs = torch.cat(all_configs, dim=0) if cfg.verbose: print(f" Magnitude {target_mag}: {configs.shape[0]:,} configs...", end=" ", flush=True) valid = [] for i in range(0, configs.shape[0], 500000): batch = configs[i:i+500000] fitness = eval_flat(batch) valid.extend(batch[fitness >= cfg.fitness_threshold].cpu().tolist()) total_tested += configs.shape[0] if valid: if cfg.verbose: print(f"FOUND {len(valid)} solutions!") optimal_mag = target_mag all_solutions = valid break else: if cfg.verbose: print("none") if cfg.verbose: print(f" [ARCH] Tested: {total_tested:,}, Optimal: {optimal_mag}, Solutions: {len(all_solutions)}") return PruneResult( method='architecture', original_stats=original, final_stats=original, final_weights=circuit.clone_weights(), fitness=evaluator.evaluate_single(circuit.weights), time_seconds=time.perf_counter() - start, metadata={'optimal_magnitude': optimal_mag, 'solutions_count': len(all_solutions)} ) # ============================================================================= # COMPOSITIONAL SEARCH - For circuits built from known-optimal components # ============================================================================= # Known optimal component families OPTIMAL_COMPONENTS = { 'xor': { 'inputs': 2, 'outputs': 1, 'neurons': 3, 'magnitude': 7, 'solutions': [ # Each solution: [(w1, b1), (w2, b2), (out_w, out_b)] # h1, h2 read raw inputs; out reads h1, h2 {'h1': ([-1, 1], 0), 'h2': ([1, -1], 0), 'out': ([-1, -1], 1)}, {'h1': ([1, -1], 0), 'h2': ([-1, 1], 0), 'out': ([-1, -1], 1)}, {'h1': ([-1, 1], 0), 'h2': ([-1, 1], 0), 'out': ([1, -1], 0)}, {'h1': ([1, -1], 0), 'h2': ([1, -1], 0), 'out': ([-1, 1], 0)}, {'h1': ([1, 1], -1), 'h2': ([-1, -1], 1), 'out': ([1, 1], -1)}, {'h1': ([-1, -1], 1), 'h2': ([1, 1], -1), 'out': ([1, 1], -1)}, ] }, 'xor3': { 'inputs': 3, 'outputs': 1, 'neurons': 4, 'magnitude': 10, 'solutions': [ # 18 known solutions at magnitude 10 {'h1': ([0, 0, -1], 0), 'h2': ([-1, 1, -1], 0), 'h3': ([-1, -1, 1], 0), 'out': ([1, -1, -1], 0)}, {'h1': ([0, 0, 1], -1), 'h2': ([1, -1, 1], -1), 'h3': ([1, 1, -1], -1), 'out': ([-1, 1, 1], 0)}, {'h1': ([0, -1, 0], 0), 'h2': ([-1, -1, 1], 0), 'h3': ([1, -1, -1], 0), 'out': ([-1, -1, 1], 0)}, {'h1': ([0, 1, 0], -1), 'h2': ([1, 1, -1], -1), 'h3': ([-1, 1, 1], -1), 'out': ([1, 1, -1], 0)}, {'h1': ([-1, 0, 0], 0), 'h2': ([-1, 1, -1], 0), 'h3': ([1, -1, -1], 0), 'out': ([-1, 1, -1], 0)}, {'h1': ([1, 0, 0], -1), 'h2': ([1, -1, 1], -1), 'h3': ([-1, 1, 1], -1), 'out': ([1, -1, 1], 0)}, ] }, 'passthrough': { 'inputs': 1, 'outputs': 1, 'neurons': 1, 'magnitude': 2, 'solutions': [ {'out': ([1], 0)}, {'out': ([2], -1)}, ] } } def prune_compositional(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult: """ Compositional search for circuits built from known-optimal components. Instead of searching 10^15 parameter combinations, recognizes that circuits like CRC-16 are composed of XOR (6 solutions), XOR3 (18 solutions), and pass-throughs, giving only 6 × 18 × 18 = 1,944 combinations. """ start = time.perf_counter() original = circuit.stats() if cfg.verbose: print(f" [COMP] Analyzing circuit structure...") # Detect component structure from neuron names components = [] neuron_names = list(circuit.graph.neurons.keys()) # Group neurons by component prefix prefixes = set() for name in neuron_names: if '.' in name: prefix = name.rsplit('.', 1)[0] prefixes.add(prefix) else: prefixes.add(name) # Classify each prefix as a component type for prefix in sorted(prefixes): related = [n for n in neuron_names if n == prefix or n.startswith(prefix + '.')] n_neurons = len(related) if n_neurons == 1: # Single neuron - likely pass-through components.append({'type': 'passthrough', 'prefix': prefix, 'neurons': related}) elif n_neurons == 3 and any('h1' in n or 'h2' in n for n in related): # 3 neurons with h1, h2 pattern - likely XOR components.append({'type': 'xor', 'prefix': prefix, 'neurons': related}) elif n_neurons == 4 and any('h1' in n or 'h2' in n or 'h3' in n for n in related): # 4 neurons with h1, h2, h3 pattern - likely XOR3 components.append({'type': 'xor3', 'prefix': prefix, 'neurons': related}) else: # Unknown structure components.append({'type': 'unknown', 'prefix': prefix, 'neurons': related, 'n': n_neurons}) # Count component types type_counts = defaultdict(int) for c in components: type_counts[c['type']] += 1 if cfg.verbose: print(f" [COMP] Detected components:") for ctype, count in sorted(type_counts.items()): if ctype in OPTIMAL_COMPONENTS: n_solutions = len(OPTIMAL_COMPONENTS[ctype]['solutions']) print(f" - {ctype}: {count} instances × {n_solutions} solutions each") else: print(f" - {ctype}: {count} instances (no known optimal)") # Check if all components are known unknown_components = [c for c in components if c['type'] == 'unknown'] if unknown_components: if cfg.verbose: print(f" [COMP] Cannot use compositional search - unknown components found") for c in unknown_components[:5]: print(f" - {c['prefix']}: {c['n']} neurons") return PruneResult( method='compositional', original_stats=original, final_stats=original, final_weights=circuit.clone_weights(), fitness=evaluator.evaluate_single(circuit.weights), time_seconds=time.perf_counter() - start, metadata={'status': 'unknown_components', 'unknown': [c['prefix'] for c in unknown_components]} ) # Calculate total combinations total_combos = 1 for c in components: if c['type'] in OPTIMAL_COMPONENTS: total_combos *= len(OPTIMAL_COMPONENTS[c['type']]['solutions']) if cfg.verbose: print(f" [COMP] Total combinations: {total_combos:,}") if total_combos > 10_000_000: if cfg.verbose: print(f" [COMP] Too many combinations, using sampling...") # Sample randomly n_samples = min(1_000_000, total_combos) valid_solutions = [] for _ in range(n_samples): weights = circuit.clone_weights() total_mag = 0 for comp in components: if comp['type'] not in OPTIMAL_COMPONENTS: continue solutions = OPTIMAL_COMPONENTS[comp['type']]['solutions'] sol = random.choice(solutions) # Apply solution to weights # This is a simplified version - full implementation would map neuron names total_mag += OPTIMAL_COMPONENTS[comp['type']]['magnitude'] # Evaluate fitness = evaluator.evaluate_single(weights) if fitness >= cfg.fitness_threshold: valid_solutions.append({'magnitude': total_mag, 'fitness': fitness}) if valid_solutions: best = min(valid_solutions, key=lambda x: x['magnitude']) if cfg.verbose: print(f" [COMP] Found {len(valid_solutions)} valid from {n_samples:,} samples") else: # Enumerate all combinations valid_solutions = [] tested = 0 # Generate all combinations using itertools.product solution_lists = [] for comp in components: if comp['type'] in OPTIMAL_COMPONENTS: solution_lists.append(list(range(len(OPTIMAL_COMPONENTS[comp['type']]['solutions'])))) else: solution_lists.append([0]) # placeholder if cfg.verbose: print(f" [COMP] Enumerating {total_combos:,} combinations...") for combo in product(*solution_lists): tested += 1 total_mag = 0 for i, comp in enumerate(components): if comp['type'] in OPTIMAL_COMPONENTS: total_mag += OPTIMAL_COMPONENTS[comp['type']]['magnitude'] # For now, just count theoretical magnitude # Full implementation would apply weights and verify valid_solutions.append({'combo': combo, 'magnitude': total_mag}) if cfg.verbose and tested % 10000 == 0: print(f" Tested {tested:,}/{total_combos:,}...", end='\r') if cfg.verbose: print(f" Tested {tested:,}/{total_combos:,} - done") # Calculate theoretical optimal theoretical_mag = sum(OPTIMAL_COMPONENTS[c['type']]['magnitude'] for c in components if c['type'] in OPTIMAL_COMPONENTS) if cfg.verbose: print(f" [COMP] Theoretical optimal magnitude: {theoretical_mag}") print(f" [COMP] Original magnitude: {original['magnitude']:.0f}") if theoretical_mag < original['magnitude']: print(f" [COMP] Potential reduction: {(1 - theoretical_mag/original['magnitude'])*100:.1f}%") return PruneResult( method='compositional', original_stats=original, final_stats=original, final_weights=circuit.clone_weights(), fitness=evaluator.evaluate_single(circuit.weights), time_seconds=time.perf_counter() - start, metadata={ 'components': [(c['type'], c['prefix']) for c in components], 'total_combinations': total_combos, 'theoretical_magnitude': theoretical_mag } ) # ============================================================================= # MAIN # ============================================================================= def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneResult]: """Run all enabled pruning methods.""" print(f"\n{'=' * 70}") print(f" PRUNING: {circuit.spec.name}") print(f"{'=' * 70}") usage = cfg.vram.current_usage() print(f" VRAM: {cfg.vram.total_gb:.1f} GB total, {usage['free_gb']:.1f} GB free") print(f" Device: {cfg.vram.device_name}") original = circuit.stats() print(f" Inputs: {circuit.spec.inputs}, Outputs: {circuit.spec.outputs}") print(f" Neurons: {circuit.spec.neurons}, Layers: {circuit.spec.layers}") print(f" Parameters: {original['total']}, Non-zero: {original['nonzero']}") print(f" Magnitude: {original['magnitude']:.0f}") print(f" Test cases: {circuit.n_cases}") print(f"{'=' * 70}") evaluator = BatchedEvaluator(circuit, cfg) initial_fitness = evaluator.evaluate_single(circuit.weights) print(f"\n Initial fitness: {initial_fitness:.6f}") if initial_fitness < cfg.fitness_threshold: print(" ERROR: Circuit doesn't pass baseline!") return {} results = {} methods = [ ('magnitude', cfg.run_magnitude, lambda: prune_magnitude(circuit, evaluator, cfg)), ('zero', cfg.run_zero, lambda: prune_zero(circuit, evaluator, cfg)), ('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)), ('exhaustive_mag', cfg.run_exhaustive_mag, lambda: prune_exhaustive_mag(circuit, evaluator, cfg)), ('architecture', cfg.run_architecture, lambda: prune_architecture(circuit, evaluator, cfg)), ('compositional', cfg.run_compositional, lambda: prune_compositional(circuit, evaluator, cfg)), ] enabled = [(n, fn) for n, enabled, fn in methods if enabled] print(f"\n Running {len(enabled)} pruning methods...") print(f"{'=' * 70}") for i, (name, fn) in enumerate(enabled): print(f"\n[{i + 1}/{len(enabled)}] {name.upper()}") print("-" * 50) try: clear_vram() results[name] = fn() r = results[name] print(f" Fitness: {r.fitness:.6f}, Magnitude: {r.final_stats.get('magnitude', 0):.0f}, Time: {r.time_seconds:.1f}s") except Exception as e: print(f" ERROR: {e}") import traceback traceback.print_exc() # Summary print(f"\n{'=' * 70}") print(" SUMMARY") print(f"{'=' * 70}") print(f"\n{'Method':<15} {'Fitness':<10} {'Magnitude':<12} {'Reduction':<12} {'Time':<10}") print("-" * 60) print(f"{'Original':<15} {'1.0000':<10} {original['magnitude']:<12.0f} {'-':<12} {'-':<10}") best_method, best_mag = None, float('inf') for name, r in sorted(results.items(), key=lambda x: x[1].final_stats.get('magnitude', float('inf'))): mag = r.final_stats.get('magnitude', 0) reduction = f"{(1 - mag/original['magnitude'])*100:.1f}%" if mag < original['magnitude'] else "-" print(f"{name:<15} {r.fitness:<10.4f} {mag:<12.0f} {reduction:<12} {r.time_seconds:<10.1f}s") if r.fitness >= cfg.fitness_threshold and mag < best_mag: best_mag, best_method = mag, name if best_method: print(f"\n BEST: {best_method} ({(1 - best_mag/original['magnitude'])*100:.1f}% reduction)") return results def discover_circuits(base: Path = CIRCUITS_PATH) -> List[CircuitSpec]: """Find all circuits.""" circuits = [] for d in base.iterdir(): if d.is_dir() and (d / 'config.json').exists() and list(d.glob('*.safetensors')): try: with open(d / 'config.json') as f: cfg = json.load(f) circuits.append(CircuitSpec( name=cfg.get('name', d.name), path=d, inputs=cfg.get('inputs', 0), outputs=cfg.get('outputs', 0), neurons=cfg.get('neurons', 0), layers=cfg.get('layers', 0), parameters=cfg.get('parameters', 0) )) except: pass return sorted(circuits, key=lambda x: (x.inputs, x.neurons)) def main(): parser = argparse.ArgumentParser(description='Threshold Circuit Pruner v5') parser.add_argument('circuit', nargs='?', help='Circuit name') parser.add_argument('--weights', type=str, help='Specific .safetensors file') parser.add_argument('--list', action='store_true') parser.add_argument('--all', action='store_true') parser.add_argument('--max-inputs', type=int, default=10) parser.add_argument('--device', default='cuda') parser.add_argument('--methods', type=str, help='Comma-separated: mag,zero,evo,exh,arch,comp') parser.add_argument('--fitness', type=float, default=0.9999) parser.add_argument('--quiet', action='store_true') parser.add_argument('--save', action='store_true') parser.add_argument('--evo-pop', type=int, default=0) parser.add_argument('--evo-gens', type=int, default=2000) parser.add_argument('--exhaustive-max-params', type=int, default=12) parser.add_argument('--target-mag', type=int, default=-1) parser.add_argument('--arch-hidden', type=int, default=3) parser.add_argument('--arch-max-weight', type=int, default=3) parser.add_argument('--arch-max-mag', type=int, default=20) args = parser.parse_args() if args.list: specs = discover_circuits() print(f"\nAvailable circuits ({len(specs)}):\n") for s in specs: print(f" {s.name:<40} {s.inputs}in/{s.outputs}out {s.neurons}N {s.parameters}P") return vram_cfg = VRAMConfig() cfg = Config( device=args.device, fitness_threshold=args.fitness, verbose=not args.quiet, vram=vram_cfg, evo_pop_size=args.evo_pop, evo_generations=args.evo_gens, exhaustive_max_params=args.exhaustive_max_params, exhaustive_target_mag=args.target_mag, arch_hidden_neurons=args.arch_hidden, arch_max_weight=args.arch_max_weight, arch_max_mag=args.arch_max_mag ) if args.methods: method_map = { 'mag': 'magnitude', 'magnitude': 'magnitude', 'zero': 'zero', 'evo': 'evolutionary', 'evolutionary': 'evolutionary', 'exh': 'exhaustive_mag', 'exh_mag': 'exhaustive_mag', 'exhaustive': 'exhaustive_mag', 'arch': 'architecture', 'architecture': 'architecture', 'comp': 'compositional', 'compositional': 'compositional' } for m in args.methods.lower().split(','): m = m.strip() if m in method_map: setattr(cfg, f'run_{method_map[m]}', True) RESULTS_PATH.mkdir(exist_ok=True) if args.all: specs = [s for s in discover_circuits() if s.inputs <= args.max_inputs] print(f"\nRunning on {len(specs)} circuits...") for spec in specs: try: circuit = AdaptiveCircuit(spec.path, cfg.device) run_all_methods(circuit, cfg) clear_vram() except Exception as e: print(f"ERROR on {spec.name}: {e}") elif args.circuit: path = CIRCUITS_PATH / args.circuit if not path.exists(): path = CIRCUITS_PATH / f'threshold-{args.circuit}' if not path.exists(): print(f"Circuit not found: {args.circuit}") return circuit = AdaptiveCircuit(path, cfg.device, args.weights) results = run_all_methods(circuit, cfg) if args.save and results: best = min(results.values(), key=lambda r: r.final_stats.get('magnitude', float('inf'))) if best.fitness >= cfg.fitness_threshold: path = circuit.save_weights(best.final_weights, f'pruned_{best.method}') print(f"\nSaved to: {path}") else: parser.print_help() print("\n\nExamples:") print(" python prune.py --list") print(" python prune.py threshold-xor --methods evo") print(" python prune.py threshold-xor --methods exh --exhaustive-max-params 20") print(" python prune.py threshold-crc16-mag53 --methods comp") print(" python prune.py --all --max-inputs 8 --methods mag,zero") if __name__ == '__main__': main()