|
|
""" |
|
|
Threshold Circuit Pruner v5 (Refactored)

Streamlined pruning framework with 6 core methods:
- magnitude: Greedy weight reduction
- zero: Try zeroing individual weights
- evolutionary: GPU-parallel genetic algorithm
- exhaustive_mag: Provably optimal for small circuits
- architecture: Search flat 2-layer alternatives
- compositional: For circuits built from known-optimal components

Usage:
    python prune.py threshold-xor --methods evo
    python prune.py threshold-xor --methods exh_mag
    python prune.py threshold-crc16-mag53 --methods comp
    python prune.py --list
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import json |
|
|
import time |
|
|
import random |
|
|
import argparse |
|
|
import math |
|
|
import gc |
|
|
import re |
|
|
import importlib.util |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
from pathlib import Path |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Dict, List, Tuple, Optional, Callable |
|
|
from safetensors.torch import load_file, save_file |
|
|
from collections import defaultdict |
|
|
from functools import lru_cache |
|
|
from itertools import combinations, product |
|
|
import warnings |
|
|
|
|
|
# Install a process-wide "ignore" filter for every warning category.
warnings.filterwarnings(action='ignore')

# Root directory of the circuit collection (hard-coded Windows path).
CIRCUITS_PATH = Path('D:/threshold-logic-circuits')
# Results directory, kept inside the circuit root.
RESULTS_PATH = CIRCUITS_PATH.joinpath('pruned_results')
|
|
|
|
|
|
|
|
@dataclass
class VRAMConfig:
    """Memory budget derived from CUDA device 0, or zeros when CUDA is absent.

    ``target_residency`` and ``safety_margin`` are fractions of the device's
    total memory; the usable budget is their difference times the total.
    ``total_bytes`` and ``device_name`` are filled in after construction and
    are not dataclass fields.
    """

    target_residency: float = 0.75
    safety_margin: float = 0.05

    def __post_init__(self):
        """Record total memory and name of CUDA device 0, else CPU placeholders."""
        if torch.cuda.is_available():
            device_props = torch.cuda.get_device_properties(0)
            self.total_bytes, self.device_name = device_props.total_memory, device_props.name
        else:
            self.total_bytes, self.device_name = 0, "CPU"

    @property
    def total_gb(self) -> float:
        """Total device memory in units of 1e9 bytes."""
        return self.total_bytes / 1e9

    @property
    def available_bytes(self) -> int:
        """Bytes the caller may plan to use: total * (residency - margin), truncated."""
        usable_fraction = self.target_residency - self.safety_margin
        return int(self.total_bytes * usable_fraction)

    def current_usage(self) -> Dict:
        """Return allocated/free memory (1e9-byte units) and the utilization ratio.

        All three values are 0 when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return dict(allocated_gb=0, free_gb=0, utilization=0)

        in_use = torch.cuda.memory_allocated()
        # Guard the division: total_bytes stays 0 if no device was recorded.
        ratio = in_use / self.total_bytes if self.total_bytes > 0 else 0
        return {
            'allocated_gb': in_use / 1e9,
            'free_gb': (self.total_bytes - in_use) / 1e9,
            'utilization': ratio,
        }
|
|
|
|
|
|
|
|
def clear_vram():
    """Run the garbage collector, then release cached CUDA memory if CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    # Release cached blocks first, then wait for outstanding device work.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
|
|
|
|
|
|
|
|
@dataclass
class Config:
    """Run-wide settings shared by the evaluator and every pruning method."""

    # --- general ---------------------------------------------------------
    device: str = 'cuda'
    # Minimum fraction of test cases a candidate must reproduce to be accepted.
    fitness_threshold: float = 0.9999
    verbose: bool = True
    vram: VRAMConfig = field(default_factory=VRAMConfig)

    # --- method switches (all off by default) ------------------------------
    run_magnitude: bool = False
    run_zero: bool = False
    run_evolutionary: bool = False
    run_exhaustive_mag: bool = False
    run_architecture: bool = False
    run_compositional: bool = False

    # --- magnitude / exhaustive --------------------------------------------
    # Upper bound on the number of greedy reduction passes.
    magnitude_passes: int = 100
    # Exhaustive search is skipped for circuits with more parameters than this.
    exhaustive_max_params: int = 12
    # Negative means "scan magnitudes upward from 0"; >= 0 tests only that magnitude.
    exhaustive_target_mag: int = -1

    # --- evolutionary ------------------------------------------------------
    evo_generations: int = 2000
    # 0 (or less) lets the evolutionary method size the population itself.
    evo_pop_size: int = 0
    evo_elite_ratio: float = 0.05
    evo_mutation_rate: float = 0.15
    # Truncated to an int and used as the symmetric bound of integer mutations.
    evo_mutation_strength: float = 2.0
    evo_crossover_rate: float = 0.3
    # Weight of the mean-magnitude penalty subtracted from fitness.
    evo_parsimony: float = 0.001

    # --- architecture search -----------------------------------------------
    arch_hidden_neurons: int = 3
    arch_max_weight: int = 3
    arch_max_mag: int = 20
|
|
|
|
|
|
|
|
@dataclass
class CircuitSpec:
    """Static description of one circuit, as read from its ``config.json``."""

    name: str
    # Directory the circuit was loaded from.
    path: Path
    # Number of circuit input and output signals.
    inputs: int
    outputs: int
    # Size figures carried over from the config file.
    neurons: int
    layers: int
    parameters: int
    description: str = ""
|
|
|
|
|
|
|
|
@dataclass
class PruneResult:
    """Outcome of one pruning method run."""

    # Name of the method that produced this result.
    method: str
    # Weight statistics before and after pruning.
    original_stats: Dict
    final_stats: Dict
    # The weights the method settled on, keyed by parameter name.
    final_weights: Dict[str, torch.Tensor]
    # Fraction of test cases the final weights reproduce.
    fitness: float
    # Wall-clock duration of the run.
    time_seconds: float
    # Free-form, method-specific extras; each instance gets its own dict.
    metadata: Dict = field(default_factory=dict)
|
|
|
|
|
|
|
|
class ComputationGraph:
    """Parses weight structure to build dependency graph.

    Neurons are discovered from parameter names: every key containing
    ``.weight`` defines a neuron (the key with ``.weight`` removed) and a
    matching ``.bias`` key supplies its bias. Depth is first guessed from the
    neuron name and then reset to 0 for every neuron whose fan-in equals
    ``n_inputs``. Evaluation visits neurons in ``(depth, name)`` order and
    applies a hard step activation (``pre-activation >= 0``).
    """

    def __init__(self, weights: Dict[str, torch.Tensor], n_inputs: int, n_outputs: int, device: str):
        """Index ``weights`` into neurons and fix the evaluation order.

        Args:
            weights: Parameter tensors keyed by name.
            n_inputs: Number of raw circuit inputs.
            n_outputs: Number of neurons to expose as circuit outputs.
            device: Device used for the fallback all-zero output tensor.
        """
        self.device = device
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.weights = weights
        self.neurons = {}
        self.neuron_order = []
        self.output_neurons = []
        self.layer_groups = defaultdict(list)
        self._parse_structure()
        self._build_execution_order()

    def _parse_structure(self):
        """Group parameters by neuron name and record per-neuron metadata."""
        neuron_weights = defaultdict(dict)
        for key, tensor in self.weights.items():
            if '.weight' in key:
                neuron_name = key.replace('.weight', '')
                neuron_weights[neuron_name]['weight'] = key
                neuron_weights[neuron_name]['weight_shape'] = tensor.shape
            elif '.bias' in key:
                neuron_name = key.replace('.bias', '')
                neuron_weights[neuron_name]['bias'] = key

        for neuron_name, params in neuron_weights.items():
            # A bias without a weight tensor does not define a neuron.
            if 'weight' in params:
                depth = self._estimate_depth(neuron_name)
                self.neurons[neuron_name] = {
                    'weight_key': params.get('weight'),
                    'bias_key': params.get('bias'),
                    'weight_shape': params.get('weight_shape', (1,)),
                    'input_size': params.get('weight_shape', (1,))[-1],
                    'depth': depth
                }
                self.layer_groups[depth].append(neuron_name)

        self._infer_depth_from_shapes()
        self._identify_outputs()

    def _estimate_depth(self, name: str) -> int:
        """Guess a depth from the name: ``layer<N>`` number plus the dot count."""
        depth = 0
        if 'layer' in name:
            match = re.search(r'layer(\d+)', name)
            if match:
                depth = int(match.group(1))
        depth += len(name.split('.')) - 1
        return depth

    def _infer_depth_from_shapes(self):
        """Pin neurons whose fan-in equals ``n_inputs`` to depth 0, then regroup."""
        for name, info in self.neurons.items():
            if info.get('input_size', 0) == self.n_inputs:
                info['depth'] = 0
                info['input_source'] = 'raw'

        # Depths may have changed, so rebuild the per-depth grouping from scratch.
        self.layer_groups = defaultdict(list)
        for name, info in self.neurons.items():
            self.layer_groups[info['depth']].append(name)

    def _identify_outputs(self):
        """Pick the ``n_outputs`` deepest neurons that have no child neurons.

        A neuron is a parent when another neuron's name starts with
        ``<name>.``. A prefix test is used rather than a substring test so
        that, for example, ``or`` is not mistaken for the parent of
        ``xor.sub``.
        """
        names = list(self.neurons)
        candidates = []
        for name in names:
            prefix = name + '.'
            is_parent = any(other.startswith(prefix) for other in names if other != name)
            if not is_parent:
                candidates.append((name, self.neurons[name]['depth']))
        # Deepest first; ties broken alphabetically for determinism.
        candidates.sort(key=lambda x: (-x[1], x[0]))
        self.output_neurons = [c[0] for c in candidates[:self.n_outputs]]

    def _build_execution_order(self):
        """Order neurons by ``(depth, name)`` so shallower neurons run first."""
        self.neuron_order = sorted(self.neurons.keys(), key=lambda n: (self.neurons[n]['depth'], n))

    def forward_single(self, inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Evaluate the circuit for one weight assignment.

        Args:
            inputs: Input tensor; both a 1-D vector and a 2-D batch are handled
                (assumed layout ``(n_inputs,)`` or ``(batch, n_inputs)``).
            weights: Parameter tensors keyed by the same names used at
                construction; neurons whose weight key is missing are skipped.

        Returns:
            Output-neuron activations stacked along the last dimension, or an
            all-zero vector of length ``n_outputs`` if no output was computed.
        """
        activations = {'input': inputs}
        for neuron_name in self.neuron_order:
            info = self.neurons[neuron_name]
            w_key, b_key = info['weight_key'], info['bias_key']
            if w_key and w_key in weights:
                w = weights[w_key]
                if w.dim() == 1:
                    w = w.unsqueeze(0)
                inp = self._get_neuron_input(neuron_name, activations, inputs, w.shape[-1])
                if inp.dim() == 1:
                    out = inp @ w.flatten()
                else:
                    out = (inp.unsqueeze(-2) @ w.T.unsqueeze(0)).squeeze(-2)
                    # A single-row weight leaves a trailing size-1 axis; drop it.
                    if out.dim() > 1 and out.shape[-1] == 1:
                        out = out.squeeze(-1)
                if b_key and b_key in weights:
                    out = out + weights[b_key].squeeze()
                activations[neuron_name] = (out >= 0).float()
        return self._collect_outputs(activations)

    def _get_neuron_input(self, neuron_name: str, activations: Dict, raw_input: torch.Tensor, expected_size: int) -> torch.Tensor:
        """Choose the tensor a neuron reads from.

        Raw inputs are used when the neuron was tagged ``'raw'`` or its fan-in
        equals ``n_inputs``. For names containing ``layer2`` or ``.out``, the
        already-computed activations sharing the neuron's base name are
        stacked (sorted by name) if their count matches the fan-in. Otherwise
        the leading ``expected_size`` raw inputs are used, or the raw input
        unchanged when it is narrower than that.
        """
        info = self.neurons.get(neuron_name, {})
        if info.get('input_source') == 'raw' or expected_size == self.n_inputs:
            return raw_input
        if 'layer2' in neuron_name or '.out' in neuron_name:
            base = neuron_name.replace('.layer2', '').replace('.out', '')
            hidden_keys = [k for k in activations if k.startswith(base) and k != neuron_name and k != 'input']
            if len(hidden_keys) == expected_size:
                return torch.stack([activations[k] for k in sorted(hidden_keys)], dim=-1)
        return raw_input[..., :expected_size] if raw_input.shape[-1] >= expected_size else raw_input

    def _collect_outputs(self, activations: Dict) -> torch.Tensor:
        """Stack the computed output activations in sorted-name order."""
        outputs = [activations[n] for n in sorted(self.output_neurons) if n in activations]
        if outputs:
            # Scalars (unbatched input) cannot be stacked along dim=-1 of a 0-d tensor.
            return torch.stack(outputs, dim=-1) if outputs[0].dim() > 0 else torch.stack(outputs)
        return torch.zeros(self.n_outputs, device=self.device)
|
|
|
|
|
|
|
|
class AdaptiveCircuit:
    """Adaptive threshold circuit with automatic evaluation.

    Loads a circuit directory (``config.json`` plus a ``.safetensors`` weight
    file and an optional ``model.py`` exposing ``forward``), builds an
    exhaustive truth table for it, and provides conversions between the
    weight dict and a single flat parameter vector.
    """

    def __init__(self, path: Path, device: str = 'cuda', weights_file: str = None):
        """Load spec, weights, optional native forward, topology and test cases.

        Args:
            path: Circuit directory.
            device: Device the weights and test tensors are placed on.
            weights_file: Weight file name inside ``path``; defaults to
                ``model.safetensors``.
        """
        self.path = Path(path)
        self.device = device
        self.spec = self._load_spec()
        self.weights = self._load_weights(weights_file)
        self.weight_keys = list(self.weights.keys())
        self.n_weights = sum(t.numel() for t in self.weights.values())

        self.native_forward = self._try_load_native_forward()
        self.has_native = self.native_forward is not None
        print(f" [LOAD] Native forward: {'FOUND' if self.has_native else 'NOT FOUND'}")

        print(f" [LOAD] Parsing circuit topology...")
        self.graph = ComputationGraph(self.weights, self.spec.inputs, self.spec.outputs, device)
        print(f" [LOAD] Found {len(self.graph.neurons)} neurons across {len(self.graph.layer_groups)} layers")
        print(f" [LOAD] Output neurons: {sorted(self.graph.output_neurons)}")

        print(f" [LOAD] Building test cases...")
        self.test_inputs, self.test_expected = self._build_tests()
        self.n_cases = self.test_inputs.shape[0]
        print(f" [LOAD] Generated {self.n_cases} test cases")

        self._compile_fast_forward()
        print(f" [LOAD] Circuit ready: {self.n_weights} weight parameters")

    def _try_load_native_forward(self) -> Optional[Callable]:
        """Import ``model.py`` from the circuit directory and return its ``forward``.

        Returns ``None`` when the file is missing, has no ``forward``, or fails
        to import; loading is best-effort and never raises.
        """
        model_py = self.path / 'model.py'
        if not model_py.exists():
            return None
        try:
            spec = importlib.util.spec_from_file_location("circuit_model", model_py)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            if hasattr(module, 'forward'):
                return module.forward
            return None
        except Exception:
            return None

    def evaluate_native(self, inputs: torch.Tensor, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Evaluate ``inputs`` with the native ``forward`` (graph fallback if none).

        A single batched call is tried first. If it raises or does not return
        a tensor, each row is evaluated separately with integer inputs and CPU
        weights; a row whose call raises contributes all-zero outputs.
        """
        if not self.has_native:
            return self.graph.forward_single(inputs, weights)
        try:
            out = self.native_forward(inputs, weights)
            if isinstance(out, torch.Tensor):
                # Give 0-D/1-D results a trailing output axis.
                return out.to(self.device) if out.dim() > 1 else out.unsqueeze(-1).to(self.device)
        except Exception:
            # Best effort: fall through to the per-row path below.
            pass
        cpu_weights = {k: v.cpu() for k, v in weights.items()}
        results = []
        for i in range(inputs.shape[0]):
            inp = [int(x) for x in inputs[i].cpu().tolist()]
            try:
                out = self.native_forward(inp, cpu_weights)
                results.append([float(x) for x in out] if isinstance(out, (list, tuple)) else [float(out)])
            except Exception:
                results.append([0.0] * self.spec.outputs)
        return torch.tensor(results, device=self.device, dtype=torch.float32)

    def _load_spec(self) -> "CircuitSpec":
        """Read ``config.json``; missing numeric fields default to 0."""
        with open(self.path / 'config.json', encoding='utf-8') as f:
            cfg = json.load(f)
        return CircuitSpec(
            name=cfg.get('name', self.path.name),
            path=self.path,
            inputs=cfg.get('inputs', cfg.get('input_size', 0)),
            outputs=cfg.get('outputs', cfg.get('output_size', 0)),
            neurons=cfg.get('neurons', 0),
            layers=cfg.get('layers', 0),
            parameters=cfg.get('parameters', 0),
            description=cfg.get('description', '')
        )

    def _load_weights(self, weights_file: str = None) -> Dict[str, torch.Tensor]:
        """Load the weight file as float tensors on ``self.device``.

        If the chosen file does not exist, the alphabetically first
        ``*.safetensors`` file in the directory is used instead, so the
        fallback is reproducible across runs and platforms.
        """
        if weights_file:
            sf = self.path / weights_file
        else:
            sf = self.path / 'model.safetensors'
        if not sf.exists():
            candidates = sorted(self.path.glob('*.safetensors'))
            sf = candidates[0] if candidates else sf
        w = load_file(str(sf))
        return {k: v.float().to(self.device) for k, v in w.items()}

    def _build_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the truth table, using the native forward when one was found."""
        if self.has_native:
            return self._build_native_tests()
        return self._build_exhaustive_tests()

    def _build_exhaustive_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Enumerate all 2^n inputs and label them with the graph's own outputs.

        Raises:
            ValueError: If the circuit has more than 24 inputs.
        """
        n = self.spec.inputs
        if n > 24:
            raise ValueError(f"Input space too large: 2^{n}")
        n_cases = 2 ** n
        idx = torch.arange(n_cases, device=self.device, dtype=torch.long)
        bits = torch.arange(n, device=self.device, dtype=torch.long)
        # Bit b of the case index is input b (least-significant bit first).
        inputs = ((idx.unsqueeze(1) >> bits) & 1).float()
        expected = self.graph.forward_single(inputs, self.weights)
        return inputs, expected

    def _build_native_tests(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Enumerate all 2^n inputs and label them with the native forward.

        Raises:
            ValueError: If the circuit has more than 20 inputs.
        """
        n = self.spec.inputs
        if n > 20:
            raise ValueError(f"Input space too large: 2^{n}")
        n_cases = 2 ** n
        inputs_list, expected_list = [], []
        cpu_weights = {k: v.cpu() for k, v in self.weights.items()}
        for i in range(n_cases):
            inp = [(i >> b) & 1 for b in range(n)]
            inputs_list.append(inp)
            out = self.native_forward(inp, cpu_weights)
            expected_list.append([float(x) for x in out] if isinstance(out, (list, tuple)) else [float(out)])
        return (torch.tensor(inputs_list, device=self.device, dtype=torch.float32),
                torch.tensor(expected_list, device=self.device, dtype=torch.float32))

    def _compile_fast_forward(self):
        """Record each tensor's slice in the flat vector and cache the base vector."""
        self.weight_layout = []
        offset = 0
        for key in self.weight_keys:
            size = self.weights[key].numel()
            self.weight_layout.append((key, offset, offset + size, self.weights[key].shape))
            offset += size
        self.base_vector = self.weights_to_vector(self.weights)

    def weights_to_vector(self, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Concatenate the flattened tensors in ``weight_keys`` order."""
        return torch.cat([weights[k].flatten() for k in self.weight_keys])

    def vector_to_weights(self, vector: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Split a flat vector back into named tensors (views into ``vector``)."""
        weights = {}
        for key, start, end, shape in self.weight_layout:
            weights[key] = vector[start:end].view(shape)
        return weights

    def clone_weights(self) -> Dict[str, torch.Tensor]:
        """Return an independent copy of the circuit's weights."""
        return {k: v.clone() for k, v in self.weights.items()}

    def stats(self, weights: Dict[str, torch.Tensor] = None) -> Dict:
        """Summarise a weight dict (the circuit's own weights when omitted).

        Only ``None`` selects the circuit's weights; an explicitly passed
        empty dict is summarised as empty. Zero-element tensors are ignored
        when computing ``max_weight``.

        Returns:
            Dict with ``total``, ``nonzero``, ``sparsity``, ``magnitude``
            (sum of absolute values) and ``max_weight`` (largest absolute value).
        """
        w = self.weights if weights is None else weights
        total = sum(t.numel() for t in w.values())
        nonzero = sum((t != 0).sum().item() for t in w.values())
        mag = sum(t.abs().sum().item() for t in w.values())
        maxw = max((t.abs().max().item() for t in w.values() if t.numel() > 0), default=0)
        return {
            'total': total,
            'nonzero': nonzero,
            'sparsity': 1 - nonzero / total if total else 0,
            'magnitude': mag,
            'max_weight': maxw
        }

    def save_weights(self, weights: Dict[str, torch.Tensor], suffix: str = 'pruned') -> Path:
        """Write ``weights`` to ``model_<suffix>.safetensors`` and return its path."""
        path = self.path / f'model_{suffix}.safetensors'
        save_file({k: v.cpu() for k, v in weights.items()}, str(path))
        return path
|
|
|
|
|
|
|
|
class BatchedEvaluator:
    """GPU-optimized batched population evaluation.

    Scores weight assignments by the fraction of the circuit's test cases
    whose outputs match the expected outputs on every output bit.
    Populations are split into chunks of at most ``max_batch`` rows; within a
    chunk individuals are scored one after another.
    """

    def __init__(self, circuit: "AdaptiveCircuit", cfg: "Config"):
        """Bind to ``circuit``, size the batch, and sanity-check the original weights."""
        self.circuit = circuit
        self.cfg = cfg
        self.device = cfg.device
        self.test_inputs = circuit.test_inputs
        self.test_expected = circuit.test_expected
        self.n_cases = circuit.n_cases
        self.n_weights = circuit.n_weights

        if cfg.verbose:
            print(f" [EVAL] Initializing evaluator...")

        self._calculate_batch_size()
        self._validate_evaluation()

        if cfg.verbose:
            print(f" [EVAL] Evaluator ready: batch={self.max_batch:,}, native={circuit.has_native}")

    def _calculate_batch_size(self):
        """Derive ``max_batch`` from the memory budget, clamped to [1000, 5,000,000]."""
        # Per-individual estimate: two float32 copies of the weights, one
        # float32 output table, plus a fixed 4 KiB allowance.
        bytes_per_ind = self.n_weights * 4 * 2 + self.n_cases * self.circuit.spec.outputs * 4 + 4096
        available = self.cfg.vram.available_bytes
        self.max_batch = max(1000, min(available // max(bytes_per_ind, 1), 5_000_000))

    def _validate_evaluation(self):
        """Warn (when verbose) if the unmodified weights score below 0.999."""
        fitness = self.evaluate_single(self.circuit.weights)
        if fitness < 0.999 and self.cfg.verbose:
            print(f" [EVAL WARNING] Original weights fitness={fitness:.4f}")

    def _fraction_correct(self, weights: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Return the fraction of fully matching test cases as a 0-dim tensor.

        A 1-D output is broadcast across all test cases. If the output shape
        still differs from the expected shape the score is 0, so a malformed
        forward pass can neither broadcast into a misleading score nor raise.
        """
        if self.circuit.has_native:
            outputs = self.circuit.evaluate_native(self.test_inputs, weights)
        else:
            outputs = self.circuit.graph.forward_single(self.test_inputs, weights)
        if outputs.shape != self.test_expected.shape and outputs.dim() == 1:
            outputs = outputs.unsqueeze(0).expand(self.test_expected.shape[0], -1)
        if outputs.shape != self.test_expected.shape:
            return torch.zeros((), device=self.device)
        correct = (outputs == self.test_expected).all(dim=-1).float().sum()
        return correct / self.n_cases

    def evaluate_single(self, weights: Dict[str, torch.Tensor]) -> float:
        """Score one weight dict; returns a Python float in [0, 1]."""
        with torch.no_grad():
            return self._fraction_correct(weights).item()

    def evaluate_population(self, population: torch.Tensor) -> torch.Tensor:
        """Score every row of ``population`` (one flat weight vector per row)."""
        pop_size = population.shape[0]
        if pop_size > self.max_batch:
            return self._evaluate_chunked(population)
        return self._evaluate_sequential(population)

    def _evaluate_sequential(self, population: torch.Tensor) -> torch.Tensor:
        """Score the rows one by one and return the fitness vector."""
        pop_size = population.shape[0]
        fitness = torch.zeros(pop_size, device=self.device)
        with torch.no_grad():
            for i in range(pop_size):
                weights = self.circuit.vector_to_weights(population[i])
                fitness[i] = self._fraction_correct(weights)
        return fitness

    def _evaluate_chunked(self, population: torch.Tensor) -> torch.Tensor:
        """Score the population in slices of at most ``max_batch`` rows."""
        pop_size = population.shape[0]
        fitness = torch.zeros(pop_size, device=self.device)
        for start in range(0, pop_size, self.max_batch):
            end = min(start + self.max_batch, pop_size)
            fitness[start:end] = self._evaluate_sequential(population[start:end])
        return fitness
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prune_magnitude(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Iterative magnitude reduction - try reducing each weight toward zero.

    Each pass lists every non-zero parameter, shuffles the list, and tries
    moving each one a single unit toward zero. A change is kept only if the
    circuit still scores at least ``cfg.fitness_threshold``. The search stops
    after ``cfg.magnitude_passes`` passes, when nothing is left to reduce, or
    after a pass that made no reduction.

    Args:
        circuit: Source of the starting weights and of weight statistics.
        evaluator: Scores candidate weight dicts.
        cfg: Supplies ``magnitude_passes``, ``fitness_threshold`` and ``verbose``.

    Returns:
        A ``PruneResult`` holding the reduced weights, their statistics,
        their fitness and the elapsed time.
    """
    start = time.perf_counter()
    weights = circuit.clone_weights()
    original = circuit.stats(weights)

    if cfg.verbose:
        print(f" Starting magnitude reduction (passes={cfg.magnitude_passes})...")

    for pass_num in range(cfg.magnitude_passes):
        candidates = []
        for name, tensor in weights.items():
            # One bulk transfer instead of a device round-trip per element.
            for i, val in enumerate(tensor.flatten().tolist()):
                if val != 0:
                    # Step one unit toward zero but never across it: a plain
                    # +/-1 would turn 0.5 into -0.5 and reduce nothing.
                    new_val = 0.0 if abs(val) <= 1 else val - math.copysign(1.0, val)
                    candidates.append((name, i, tensor.shape, val, new_val))

        if not candidates:
            break

        random.shuffle(candidates)
        reductions = 0

        for name, idx, shape, old_val, new_val in candidates:
            flat = weights[name].flatten()
            flat[idx] = new_val
            weights[name] = flat.view(shape)
            if evaluator.evaluate_single(weights) >= cfg.fitness_threshold:
                reductions += 1
            else:
                # Rejected: restore the previous value.
                flat[idx] = old_val
                weights[name] = flat.view(shape)

        if cfg.verbose:
            stats = circuit.stats(weights)
            print(f" Pass {pass_num}: {reductions} reductions, mag={stats['magnitude']:.0f}")

        if reductions == 0:
            break

    return PruneResult(
        method='magnitude',
        original_stats=original,
        final_stats=circuit.stats(weights),
        final_weights=weights,
        fitness=evaluator.evaluate_single(weights),
        time_seconds=time.perf_counter() - start
    )
|
|
|
|
|
|
|
|
def prune_zero(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Zero pruning - try setting each non-zero weight to zero.

    Every non-zero parameter is visited once, in random order. It is set to
    zero and stays zero only if the circuit still scores at least
    ``cfg.fitness_threshold``; otherwise its previous value is restored.
    """
    started_at = time.perf_counter()
    working = circuit.clone_weights()
    baseline = circuit.stats(working)

    # (parameter name, flat index, tensor shape, current value) per non-zero entry.
    candidates = [
        (key, pos, tensor.shape, value)
        for key, tensor in working.items()
        for pos, value in enumerate(tensor.flatten().tolist())
        if value != 0
    ]
    random.shuffle(candidates)

    if cfg.verbose:
        print(f" Testing {len(candidates)} non-zero weights for zeroing...")

    zeroed = 0
    for key, pos, shape, previous in candidates:
        flat = working[key].flatten()
        flat[pos] = 0
        working[key] = flat.view(shape)

        still_valid = evaluator.evaluate_single(working) >= cfg.fitness_threshold
        if still_valid:
            zeroed += 1
            continue

        # The circuit broke: put the old value back.
        flat[pos] = previous
        working[key] = flat.view(shape)

    if cfg.verbose:
        print(f" Zeroed {zeroed}/{len(candidates)} weights")

    final_fitness = evaluator.evaluate_single(working)
    return PruneResult(
        method='zero',
        original_stats=baseline,
        final_stats=circuit.stats(working),
        final_weights=working,
        fitness=final_fitness,
        time_seconds=time.perf_counter() - started_at
    )
|
|
|
|
|
|
|
|
def prune_evolutionary(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Evolutionary search with GPU-optimized parallel population evaluation.

    The population starts as copies of the circuit's flat weight vector. The
    leading block is left exact, the rest of the first half gets light
    integer noise on a tenth of the parameters, and the second half gets
    heavier noise on a quarter of them. Each generation scores the
    population, penalises mean magnitude by ``cfg.evo_parsimony``, carries
    the top ``elite_size`` individuals over unchanged, and fills the rest
    with softmax-sampled parents that undergo single-point crossover and
    integer mutation. The mutation rate grows while the search stagnates and
    shrinks after an improvement; the run stops early after more than 200
    stagnant generations.

    Args:
        circuit: Source of the starting weights and vector/dict conversions.
        evaluator: Scores whole populations and single weight dicts.
        cfg: Supplies the ``evo_*`` settings, ``fitness_threshold``,
            ``device``, ``vram`` and ``verbose``.

    Returns:
        A ``PruneResult`` with the best valid individual found (the original
        weights if none beat them). ``metadata['generations']`` is the number
        of generations actually run, which is 0 when ``cfg.evo_generations``
        is 0.
    """
    start = time.perf_counter()
    original = circuit.stats()

    if cfg.evo_pop_size > 0:
        pop_size = cfg.evo_pop_size
    else:
        # Auto-size from the memory budget, clamped to [10000, 500000].
        bytes_per_ind = circuit.n_weights * 4 * 3
        available = cfg.vram.available_bytes - torch.cuda.memory_allocated() if torch.cuda.is_available() else cfg.vram.available_bytes
        pop_size = max(10000, min(available // max(bytes_per_ind, 1), evaluator.max_batch, 500000))

    elite_size = max(1, int(pop_size * cfg.evo_elite_ratio))

    if cfg.verbose:
        print(f" [EVO] Population: {pop_size:,}, Elite: {elite_size:,}, Generations: {cfg.evo_generations}")

    base_vector = circuit.weights_to_vector(circuit.weights)
    population = base_vector.unsqueeze(0).expand(pop_size, -1).clone()

    # Rows [0, n_exact) stay exact copies of the original circuit.
    n_exact = max(elite_size, pop_size // 10)
    for i in range(n_exact, pop_size // 2):
        n_muts = max(1, circuit.n_weights // 10)
        mut_idx = torch.randperm(circuit.n_weights)[:n_muts]
        population[i, mut_idx] += torch.randint(-1, 2, (n_muts,), device=cfg.device, dtype=population.dtype)
    for i in range(pop_size // 2, pop_size):
        n_muts = max(1, circuit.n_weights // 4)
        mut_idx = torch.randperm(circuit.n_weights)[:n_muts]
        population[i, mut_idx] += torch.randint(-2, 3, (n_muts,), device=cfg.device, dtype=population.dtype)

    best_weights = circuit.clone_weights()
    best_fitness = evaluator.evaluate_single(best_weights)
    best_mag = original['magnitude']
    # An invalid starting circuit must lose to any valid individual.
    best_score = best_fitness - cfg.evo_parsimony * best_mag / circuit.n_weights if best_fitness >= cfg.fitness_threshold else -float('inf')
    stagnant = 0
    mutation_rate = cfg.evo_mutation_rate
    # Tracked explicitly so the result is well defined when no generation runs.
    generations_run = 0
    # Single-point crossover needs at least two parameters to have a cut point.
    can_crossover = cfg.evo_crossover_rate > 0 and circuit.n_weights > 1

    for gen in range(cfg.evo_generations):
        generations_run = gen + 1
        fitness = evaluator.evaluate_population(population)
        magnitudes = population.abs().sum(dim=1)
        adjusted = fitness - cfg.evo_parsimony * magnitudes / circuit.n_weights

        valid_mask = fitness >= cfg.fitness_threshold
        n_valid = valid_mask.sum().item()

        if n_valid > 0:
            # Only individuals that still implement the circuit may become "best".
            valid_adjusted = adjusted.clone()
            valid_adjusted[~valid_mask] = -float('inf')
            best_idx = valid_adjusted.argmax().item()
            if adjusted[best_idx] > best_score:
                best_score = adjusted[best_idx].item()
                best_fitness = fitness[best_idx].item()
                best_weights = circuit.vector_to_weights(population[best_idx].clone())
                best_mag = magnitudes[best_idx].item()
                stagnant = 0
            else:
                stagnant += 1
        else:
            stagnant += 1

        if stagnant > 50:
            mutation_rate = min(0.5, mutation_rate * 1.1)
        elif stagnant == 0:
            mutation_rate = max(0.01, mutation_rate * 0.95)

        if cfg.verbose and (gen % 50 == 0 or gen == cfg.evo_generations - 1):
            print(f" Gen {gen:4d} | valid: {n_valid:6,}/{pop_size:,} | mag: {best_mag:.0f} | stag: {stagnant}")

        sorted_idx = adjusted.argsort(descending=True)
        elite = population[sorted_idx[:elite_size]].clone()
        probs = F.softmax(adjusted * 10, dim=0)
        parent_idx = torch.multinomial(probs, pop_size - elite_size, replacement=True)
        children = population[parent_idx].clone()

        if can_crossover:
            cross_mask = torch.rand(len(children), device=cfg.device) < cfg.evo_crossover_rate
            cross_idx = torch.where(cross_mask)[0]
            # Pair up selected children and swap their tails after a random cut.
            for i in range(0, len(cross_idx) - 1, 2):
                p1, p2 = cross_idx[i].item(), cross_idx[i + 1].item()
                point = random.randint(1, circuit.n_weights - 1)
                temp = children[p1, point:].clone()
                children[p1, point:] = children[p2, point:]
                children[p2, point:] = temp

        mut_mask = torch.rand_like(children) < mutation_rate
        mutations = torch.randint(-int(cfg.evo_mutation_strength), int(cfg.evo_mutation_strength) + 1,
                                  children.shape, device=cfg.device, dtype=children.dtype)
        children = children + mut_mask * mutations
        population = torch.cat([elite, children], dim=0)

        if stagnant > 200:
            if cfg.verbose:
                print(f" [EVO] Early stop at generation {gen}")
            break

    if cfg.verbose:
        final_stats = circuit.stats(best_weights)
        print(f" [EVO] Final: mag={final_stats['magnitude']:.0f} (was {original['magnitude']:.0f})")

    return PruneResult(
        method='evolutionary',
        original_stats=original,
        final_stats=circuit.stats(best_weights),
        final_weights=best_weights,
        fitness=best_fitness,
        time_seconds=time.perf_counter() - start,
        metadata={'generations': generations_run, 'population_size': pop_size}
    )
|
|
|
|
|
|
|
|
def _partitions(total: int, n: int, max_val: int): |
|
|
"""Generate all ways to partition 'total' into 'n' non-negative integers <= max_val.""" |
|
|
if n == 0: |
|
|
if total == 0: |
|
|
yield [] |
|
|
return |
|
|
for i in range(min(total, max_val) + 1): |
|
|
for rest in _partitions(total - i, n - 1, max_val): |
|
|
yield [i] + rest |
|
|
|
|
|
|
|
|
def _all_signs(abs_vals: list): |
|
|
"""Generate all sign combinations for absolute values.""" |
|
|
if not abs_vals: |
|
|
yield [] |
|
|
return |
|
|
for rest in _all_signs(abs_vals[1:]): |
|
|
if abs_vals[0] == 0: |
|
|
yield [0] + rest |
|
|
else: |
|
|
yield [abs_vals[0]] + rest |
|
|
yield [-abs_vals[0]] + rest |
|
|
|
|
|
|
|
|
def _configs_at_magnitude(mag: int, n_params: int):
    """Generate all n_params-length configs with given total magnitude.

    Each config is a tuple of signed integers whose absolute values sum to
    ``mag``: every split of ``mag`` over the slots, in every sign pattern.
    """
    for split in _partitions(mag, n_params, mag):
        yield from map(tuple, _all_signs(split))
|
|
|
|
|
|
|
|
def prune_exhaustive_mag(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Exhaustive search by magnitude - finds provably optimal solutions.

    Integer weight assignments are enumerated by total magnitude. Unless
    ``cfg.exhaustive_target_mag`` is >= 0 (then only that magnitude is
    tested), magnitudes are scanned upward from 0 to the original magnitude
    and the scan stops at the first magnitude with a valid assignment.

    Args:
        circuit: Source of the original weights and their layout.
        evaluator: Scores populations and single weight dicts.
        cfg: Supplies ``exhaustive_max_params``, ``exhaustive_target_mag``,
            ``fitness_threshold``, ``device`` and ``verbose``.

    Returns:
        A ``PruneResult`` holding the first valid assignment at the best
        magnitude found, or the original weights if none was found. When the
        circuit has more than ``cfg.exhaustive_max_params`` parameters the
        search is skipped and ``metadata`` is ``{'skipped': True}``.
    """
    start = time.perf_counter()
    original = circuit.stats()

    n_params = original['total']
    max_mag = int(original['magnitude'])
    target_mag = cfg.exhaustive_target_mag

    if n_params > cfg.exhaustive_max_params:
        if cfg.verbose:
            print(f" [EXHAUSTIVE] Skipping: {n_params} params > max {cfg.exhaustive_max_params}")
        return PruneResult(
            method='exhaustive_mag',
            original_stats=original,
            final_stats=original,
            final_weights=circuit.clone_weights(),
            fitness=evaluator.evaluate_single(circuit.weights),
            time_seconds=time.perf_counter() - start,
            metadata={'skipped': True}
        )

    if cfg.verbose:
        print(f" [EXHAUSTIVE] Parameters: {n_params}, Original magnitude: {max_mag}")
        if target_mag >= 0:
            print(f" [EXHAUSTIVE] Target magnitude: {target_mag}")

    weight_keys = list(circuit.weights.keys())
    weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
    weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}

    def vector_to_weights(vec):
        """Turn a flat config tuple into a weight dict of float32 tensors."""
        weights = {}
        idx = 0
        for k in weight_keys:
            size = weight_sizes[k]
            weights[k] = torch.tensor(vec[idx:idx+size], dtype=torch.float32, device=cfg.device).view(weight_shapes[k])
            idx += size
        return weights

    all_solutions = []
    optimal_mag = None
    total_tested = 0

    mag_range = [target_mag] if target_mag >= 0 else range(0, max_mag + 1)

    for mag in mag_range:
        configs = list(_configs_at_magnitude(mag, n_params))
        if not configs:
            continue

        if cfg.verbose:
            print(f" Magnitude {mag}: {len(configs):,} configurations...", end=" ", flush=True)

        valid = []
        batch_size = min(100000, len(configs))

        for batch_start in range(0, len(configs), batch_size):
            batch = configs[batch_start:batch_start + batch_size]
            population = torch.tensor(batch, dtype=torch.float32, device=cfg.device)
            try:
                fitness = evaluator.evaluate_population(population)
            except Exception:
                # Fall back to scoring one config at a time. Only ordinary
                # errors are caught, so Ctrl-C and SystemExit still abort.
                fitness = torch.tensor([evaluator.evaluate_single(vector_to_weights(c)) for c in batch], device=cfg.device)
            for i, is_valid in enumerate((fitness >= cfg.fitness_threshold).tolist()):
                if is_valid:
                    valid.append(batch[i])

        total_tested += len(configs)

        if valid:
            if cfg.verbose:
                print(f"FOUND {len(valid)} solutions!")
            optimal_mag = mag
            all_solutions = valid

            if cfg.verbose and len(valid) <= 50:
                print(f" Solutions:")
                for i, sol in enumerate(valid[:20]):
                    nz = sum(1 for v in sol if v != 0)
                    print(f" {i+1}: mag={sum(abs(v) for v in sol)}, nz={nz}, {sol}")
            # Magnitudes are scanned in ascending order, so the first hit is minimal.
            break
        else:
            if cfg.verbose:
                print("none")

    if all_solutions:
        best_weights = vector_to_weights(all_solutions[0])
        best_fitness = evaluator.evaluate_single(best_weights)
    else:
        # Nothing valid was found: report the original circuit unchanged.
        best_weights = circuit.clone_weights()
        best_fitness = evaluator.evaluate_single(best_weights)
        optimal_mag = max_mag

    if cfg.verbose:
        print(f" [EXHAUSTIVE] Tested: {total_tested:,}, Optimal: {optimal_mag}, Solutions: {len(all_solutions)}")

    return PruneResult(
        method='exhaustive_mag',
        original_stats=original,
        final_stats=circuit.stats(best_weights),
        final_weights=best_weights,
        fitness=best_fitness,
        time_seconds=time.perf_counter() - start,
        metadata={'optimal_magnitude': optimal_mag, 'solutions_count': len(all_solutions), 'all_solutions': all_solutions[:100]}
    )
|
|
|
|
|
|
|
|
def prune_architecture(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """Architecture search - find optimal flat 2-layer architecture.

    Enumerates integer parameter vectors for a flat network (inputs ->
    ``cfg.arch_hidden_neurons`` hidden units -> outputs, each unit firing when
    its pre-activation is ``>= 0``) in order of increasing total magnitude
    (sum of absolute parameter values), and stops at the first magnitude for
    which at least one vector reaches ``cfg.fitness_threshold`` on
    ``circuit.test_inputs`` / ``circuit.test_expected``.

    Flat layout of one candidate vector:
        for each hidden unit: ``n_inputs`` weights followed by one bias;
        then for each output unit: ``n_hidden`` weights followed by one bias.

    Args:
        circuit: Circuit whose test cases are matched; its weights are not modified.
        evaluator: Used only to score the unmodified circuit for the result.
        cfg: Reads ``arch_hidden_neurons``, ``arch_max_weight``, ``arch_max_mag``,
            ``fitness_threshold``, ``device`` and ``verbose``.

    Returns:
        A PruneResult whose stats, weights and fitness describe the ORIGINAL
        circuit (flat candidates are not written back into the circuit here).
        The search outcome is in ``metadata``: ``optimal_magnitude`` (``None``
        when nothing was found up to ``arch_max_mag``), ``solutions_count`` and
        ``all_solutions`` (at most 100 flat parameter vectors in the layout above).
    """
    start = time.perf_counter()
    original = circuit.stats()

    n_hidden = cfg.arch_hidden_neurons
    n_inputs = circuit.spec.inputs
    n_outputs = circuit.spec.outputs
    max_weight = cfg.arch_max_weight
    max_mag = cfg.arch_max_mag

    n_params = n_hidden * (n_inputs + 1) + n_outputs * (n_hidden + 1)

    if cfg.verbose:
        print(f" [ARCH] Hidden: {n_hidden}, Params: {n_params}, Max magnitude: {max_mag}")

    test_inputs = circuit.test_inputs
    test_expected = circuit.test_expected

    def eval_flat(configs: torch.Tensor) -> torch.Tensor:
        """Return, per row of *configs*, the fraction of test outputs reproduced."""
        # Slice the flat vector into (weights, bias) pairs following the layout
        # described in the enclosing docstring.
        offset = 0
        hidden_w, hidden_b = [], []
        for _ in range(n_hidden):
            hidden_w.append(configs[:, offset:offset + n_inputs])
            offset += n_inputs
            hidden_b.append(configs[:, offset:offset + 1])
            offset += 1
        output_w, output_b = [], []
        for _ in range(n_outputs):
            output_w.append(configs[:, offset:offset + n_hidden])
            offset += n_hidden
            output_b.append(configs[:, offset:offset + 1])
            offset += 1

        # Hidden layer: one thresholded activation per (config, test case, unit).
        hidden_acts = []
        for w, b in zip(hidden_w, hidden_b):
            pre = (w.unsqueeze(1) * test_inputs.unsqueeze(0)).sum(dim=2) + b
            hidden_acts.append((pre >= 0).float())
        hidden_stack = torch.stack(hidden_acts, dim=2)

        # Output layer, thresholded the same way.
        outputs = []
        for w, b in zip(output_w, output_b):
            pre = (hidden_stack * w.unsqueeze(1)).sum(dim=2) + b
            outputs.append((pre >= 0).float())

        if n_outputs == 1:
            predicted = outputs[0]
            expected = test_expected.squeeze()
        else:
            predicted = torch.stack(outputs, dim=2)
            expected = test_expected

        # Average over test cases, then (multi-output only) over outputs.
        correct = (predicted == expected.unsqueeze(0)).float().mean(dim=1)
        if n_outputs > 1:
            correct = correct.mean(dim=1)
        return correct

    @lru_cache(maxsize=None)
    def partitions(total: int, n_slots: int, max_val: int) -> list:
        """All tuples of *n_slots* values in ``0..max_val`` that sum to *total*."""
        if n_slots == 0:
            return [()] if total == 0 else []
        if n_slots == 1:
            return [(total,)] if total <= max_val else []
        result = []
        for v in range(min(total, max_val) + 1):
            for rest in partitions(total - v, n_slots - 1, max_val):
                result.append((v,) + rest)
        return result

    def signs_for_partition(partition: tuple) -> torch.Tensor:
        """Expand absolute values into every +/- sign assignment of the non-zero slots."""
        n = len(partition)
        nonzero_idx = [i for i, v in enumerate(partition) if v != 0]
        k = len(nonzero_idx)
        if k == 0:
            return torch.zeros(1, n, device=cfg.device, dtype=torch.float32)
        n_patterns = 2 ** k
        configs = torch.zeros(n_patterns, n, device=cfg.device, dtype=torch.float32)
        for bit, slot in enumerate(nonzero_idx):
            # Bit `bit` of the pattern number selects the sign of this slot.
            signs = ((torch.arange(n_patterns, device=cfg.device) >> bit) & 1) * 2 - 1
            configs[:, slot] = signs.float() * partition[slot]
        return configs

    all_solutions = []
    optimal_mag = None
    total_tested = 0
    chunk_size = 500000  # rows evaluated per eval_flat call

    for target_mag in range(1, max_mag + 1):
        all_configs = [
            signs_for_partition(partition)
            for partition in partitions(target_mag, n_params, max_weight)
        ]
        if not all_configs:
            continue
        configs = torch.cat(all_configs, dim=0)

        if cfg.verbose:
            print(f" Magnitude {target_mag}: {configs.shape[0]:,} configs...", end=" ", flush=True)

        valid = []
        for i in range(0, configs.shape[0], chunk_size):
            batch = configs[i:i + chunk_size]
            fitness = eval_flat(batch)
            valid.extend(batch[fitness >= cfg.fitness_threshold].cpu().tolist())

        total_tested += configs.shape[0]

        if valid:
            if cfg.verbose:
                print(f"FOUND {len(valid)} solutions!")
            optimal_mag = target_mag
            all_solutions = valid
            # Magnitudes are visited in increasing order, so the first hit is minimal.
            break
        if cfg.verbose:
            print("none")

    if cfg.verbose:
        print(f" [ARCH] Tested: {total_tested:,}, Optimal: {optimal_mag}, Solutions: {len(all_solutions)}")

    return PruneResult(
        method='architecture',
        original_stats=original,
        final_stats=original,
        final_weights=circuit.clone_weights(),
        fitness=evaluator.evaluate_single(circuit.weights),
        time_seconds=time.perf_counter() - start,
        metadata={
            'optimal_magnitude': optimal_mag,
            'solutions_count': len(all_solutions),
            # Previously the discovered vectors were dropped and only counted;
            # expose a capped sample, as the exhaustive method does.
            'all_solutions': all_solutions[:100],
        }
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Table of known building blocks, keyed by component type.
#
# Each entry records the component's input/output/neuron counts, a declared
# 'magnitude', and a list of 'solutions'. A solution maps a neuron label
# ('h1', 'h2', 'h3', 'out') to a (weights, bias) pair.
#
# NOTE(review): summing |weight| + |bias| over a listed solution does not always
# equal the declared 'magnitude' (e.g. the third 'xor' entry sums to 6 and the
# second 'passthrough' entry sums to 3) - confirm these tables against the
# evaluator before relying on individual entries.
OPTIMAL_COMPONENTS = {
    'xor': dict(
        inputs=2,
        outputs=1,
        neurons=3,
        magnitude=7,
        solutions=[
            {'h1': ([-1, 1], 0), 'h2': ([1, -1], 0), 'out': ([-1, -1], 1)},
            {'h1': ([1, -1], 0), 'h2': ([-1, 1], 0), 'out': ([-1, -1], 1)},
            {'h1': ([-1, 1], 0), 'h2': ([-1, 1], 0), 'out': ([1, -1], 0)},
            {'h1': ([1, -1], 0), 'h2': ([1, -1], 0), 'out': ([-1, 1], 0)},
            {'h1': ([1, 1], -1), 'h2': ([-1, -1], 1), 'out': ([1, 1], -1)},
            {'h1': ([-1, -1], 1), 'h2': ([1, 1], -1), 'out': ([1, 1], -1)},
        ],
    ),
    'xor3': dict(
        inputs=3,
        outputs=1,
        neurons=4,
        magnitude=10,
        solutions=[
            {'h1': ([0, 0, -1], 0), 'h2': ([-1, 1, -1], 0), 'h3': ([-1, -1, 1], 0), 'out': ([1, -1, -1], 0)},
            {'h1': ([0, 0, 1], -1), 'h2': ([1, -1, 1], -1), 'h3': ([1, 1, -1], -1), 'out': ([-1, 1, 1], 0)},
            {'h1': ([0, -1, 0], 0), 'h2': ([-1, -1, 1], 0), 'h3': ([1, -1, -1], 0), 'out': ([-1, -1, 1], 0)},
            {'h1': ([0, 1, 0], -1), 'h2': ([1, 1, -1], -1), 'h3': ([-1, 1, 1], -1), 'out': ([1, 1, -1], 0)},
            {'h1': ([-1, 0, 0], 0), 'h2': ([-1, 1, -1], 0), 'h3': ([1, -1, -1], 0), 'out': ([-1, 1, -1], 0)},
            {'h1': ([1, 0, 0], -1), 'h2': ([1, -1, 1], -1), 'h3': ([-1, 1, 1], -1), 'out': ([1, -1, 1], 0)},
        ],
    ),
    'passthrough': dict(
        inputs=1,
        outputs=1,
        neurons=1,
        magnitude=2,
        solutions=[
            {'out': ([1], 0)},
            {'out': ([2], -1)},
        ],
    ),
}
|
|
|
|
|
|
|
|
def prune_compositional(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """
    Compositional analysis for circuits built from known components.

    Neurons are grouped by name prefix (the text before the last '.'; a name
    without a '.' is its own prefix). Each group is classified by size and
    neuron names:

    - 1 neuron                                   -> 'passthrough'
    - 3 neurons, some name containing h1/h2      -> 'xor'
    - 4 neurons, some name containing h1/h2/h3   -> 'xor3'
    - anything else                              -> 'unknown'

    If any group is 'unknown' the function returns early with
    ``metadata['status'] == 'unknown_components'``. Otherwise it reports the
    size of the combined solution space (product of the per-component solution
    counts in OPTIMAL_COMPONENTS) and the theoretical magnitude (sum of the
    per-component 'magnitude' values).

    The circuit's weights are never modified: the returned stats, weights and
    fitness always describe the original circuit.

    NOTE(review): no code here writes an OPTIMAL_COMPONENTS solution into the
    circuit's weights. The earlier sampling / enumeration loops therefore
    re-evaluated the identical unmodified weights (up to 1,000,000 times) or
    accumulated one unused dict per combination (up to 10,000,000), without
    influencing the result. They have been removed; restore a search only
    together with the step that applies a chosen solution to the weights.

    Args:
        circuit: Circuit to analyze; reads ``graph.neurons``, ``stats()``,
            ``weights`` and ``clone_weights()``.
        evaluator: Scores the unmodified circuit for the returned result.
        cfg: Reads ``verbose``.

    Returns:
        PruneResult with ``method='compositional'``.
    """
    start = time.perf_counter()
    original = circuit.stats()

    if cfg.verbose:
        print(f" [COMP] Analyzing circuit structure...")

    neuron_names = list(circuit.graph.neurons.keys())

    # rsplit returns the whole name when it contains no '.', so undotted names
    # become their own prefix.
    prefixes = {name.rsplit('.', 1)[0] for name in neuron_names}

    components = []
    for prefix in sorted(prefixes):
        related = [n for n in neuron_names if n == prefix or n.startswith(prefix + '.')]
        n_neurons = len(related)

        if n_neurons == 1:
            ctype = 'passthrough'
        elif n_neurons == 3 and any('h1' in n or 'h2' in n for n in related):
            ctype = 'xor'
        elif n_neurons == 4 and any('h1' in n or 'h2' in n or 'h3' in n for n in related):
            ctype = 'xor3'
        else:
            ctype = 'unknown'

        component = {'type': ctype, 'prefix': prefix, 'neurons': related}
        if ctype == 'unknown':
            component['n'] = n_neurons
        components.append(component)

    type_counts = defaultdict(int)
    for c in components:
        type_counts[c['type']] += 1

    if cfg.verbose:
        print(f" [COMP] Detected components:")
        for ctype, count in sorted(type_counts.items()):
            if ctype in OPTIMAL_COMPONENTS:
                n_solutions = len(OPTIMAL_COMPONENTS[ctype]['solutions'])
                print(f" - {ctype}: {count} instances × {n_solutions} solutions each")
            else:
                print(f" - {ctype}: {count} instances (no known optimal)")

    unknown_components = [c for c in components if c['type'] == 'unknown']
    if unknown_components:
        if cfg.verbose:
            print(f" [COMP] Cannot use compositional search - unknown components found")
            for c in unknown_components[:5]:
                print(f" - {c['prefix']}: {c['n']} neurons")
        return PruneResult(
            method='compositional',
            original_stats=original,
            final_stats=original,
            final_weights=circuit.clone_weights(),
            fitness=evaluator.evaluate_single(circuit.weights),
            time_seconds=time.perf_counter() - start,
            metadata={'status': 'unknown_components', 'unknown': [c['prefix'] for c in unknown_components]}
        )

    known = [c for c in components if c['type'] in OPTIMAL_COMPONENTS]

    # Size of the joint solution space: one independent choice per component.
    total_combos = 1
    for c in known:
        total_combos *= len(OPTIMAL_COMPONENTS[c['type']]['solutions'])

    # Each component contributes its table magnitude whichever solution is
    # chosen, so the composed magnitude is identical for every combination.
    theoretical_mag = sum(OPTIMAL_COMPONENTS[c['type']]['magnitude'] for c in known)

    if cfg.verbose:
        print(f" [COMP] Total combinations: {total_combos:,}")
        print(f" [COMP] Component solutions are not applied to the weights; skipping enumeration")
        print(f" [COMP] Theoretical optimal magnitude: {theoretical_mag}")
        print(f" [COMP] Original magnitude: {original['magnitude']:.0f}")
        if theoretical_mag < original['magnitude']:
            print(f" [COMP] Potential reduction: {(1 - theoretical_mag/original['magnitude'])*100:.1f}%")

    return PruneResult(
        method='compositional',
        original_stats=original,
        final_stats=original,
        final_weights=circuit.clone_weights(),
        fitness=evaluator.evaluate_single(circuit.weights),
        time_seconds=time.perf_counter() - start,
        metadata={
            'status': 'analysis_only',
            'components': [(c['type'], c['prefix']) for c in components],
            'total_combinations': total_combos,
            'theoretical_magnitude': theoretical_mag
        }
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneResult]:
    """Run all enabled pruning methods on *circuit* and print a summary.

    Prints a header with device and circuit statistics, checks that the
    unmodified circuit reaches ``cfg.fitness_threshold``, then runs each method
    whose ``cfg.run_<name>`` flag is set. A method that raises is reported
    (message plus traceback) and skipped; the remaining methods still run.

    Args:
        circuit: Circuit to prune.
        cfg: Run configuration (method flags, fitness threshold, VRAM info).

    Returns:
        Mapping of method name to its PruneResult for every method that
        completed. Empty if the circuit fails the baseline fitness check.
    """
    import traceback

    def result_magnitude(result) -> float:
        # One accessor for sorting, display and ranking, so a result without a
        # 'magnitude' stat sorts last and can never be picked as the best.
        return result.final_stats.get('magnitude', float('inf'))

    rule = '=' * 70

    print(f"\n{rule}")
    print(f" PRUNING: {circuit.spec.name}")
    print(f"{rule}")

    usage = cfg.vram.current_usage()
    print(f" VRAM: {cfg.vram.total_gb:.1f} GB total, {usage['free_gb']:.1f} GB free")
    print(f" Device: {cfg.vram.device_name}")

    original = circuit.stats()
    print(f" Inputs: {circuit.spec.inputs}, Outputs: {circuit.spec.outputs}")
    print(f" Neurons: {circuit.spec.neurons}, Layers: {circuit.spec.layers}")
    print(f" Parameters: {original['total']}, Non-zero: {original['nonzero']}")
    print(f" Magnitude: {original['magnitude']:.0f}")
    print(f" Test cases: {circuit.n_cases}")
    print(f"{rule}")

    evaluator = BatchedEvaluator(circuit, cfg)
    initial_fitness = evaluator.evaluate_single(circuit.weights)
    print(f"\n Initial fitness: {initial_fitness:.6f}")

    if initial_fitness < cfg.fitness_threshold:
        print(" ERROR: Circuit doesn't pass baseline!")
        return {}

    results = {}
    # Lambdas defer both the call and the name lookup until a method is run.
    methods = [
        ('magnitude', cfg.run_magnitude, lambda: prune_magnitude(circuit, evaluator, cfg)),
        ('zero', cfg.run_zero, lambda: prune_zero(circuit, evaluator, cfg)),
        ('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
        ('exhaustive_mag', cfg.run_exhaustive_mag, lambda: prune_exhaustive_mag(circuit, evaluator, cfg)),
        ('architecture', cfg.run_architecture, lambda: prune_architecture(circuit, evaluator, cfg)),
        ('compositional', cfg.run_compositional, lambda: prune_compositional(circuit, evaluator, cfg)),
    ]

    enabled = [(name, fn) for name, is_enabled, fn in methods if is_enabled]
    print(f"\n Running {len(enabled)} pruning methods...")
    print(f"{rule}")

    for i, (name, fn) in enumerate(enabled):
        print(f"\n[{i + 1}/{len(enabled)}] {name.upper()}")
        print("-" * 50)
        try:
            clear_vram()
            results[name] = fn()
            r = results[name]
            print(f" Fitness: {r.fitness:.6f}, Magnitude: {r.final_stats.get('magnitude', 0):.0f}, Time: {r.time_seconds:.1f}s")
        except Exception as e:
            # Keep going: one failing method must not abort the others.
            print(f" ERROR: {e}")
            traceback.print_exc()

    print(f"\n{rule}")
    print(" SUMMARY")
    print(f"{rule}")
    print(f"\n{'Method':<15} {'Fitness':<10} {'Magnitude':<12} {'Reduction':<12} {'Time':<10}")
    print("-" * 60)
    # Show the measured baseline fitness; it can be below 1.0 when a lower
    # fitness threshold is configured.
    print(f"{'Original':<15} {initial_fitness:<10.4f} {original['magnitude']:<12.0f} {'-':<12} {'-':<10}")

    original_mag = original['magnitude']
    best_method, best_mag = None, float('inf')
    for name, r in sorted(results.items(), key=lambda item: result_magnitude(item[1])):
        mag = result_magnitude(r)
        # mag < original_mag implies original_mag > 0, so the division is safe.
        reduction = f"{(1 - mag/original_mag)*100:.1f}%" if mag < original_mag else "-"
        print(f"{name:<15} {r.fitness:<10.4f} {mag:<12.0f} {reduction:<12} {r.time_seconds:<10.1f}s")
        if r.fitness >= cfg.fitness_threshold and mag < best_mag:
            best_mag, best_method = mag, name

    if best_method:
        # A circuit whose magnitude is already 0 cannot be reduced further.
        best_reduction = (1 - best_mag/original_mag)*100 if original_mag else 0.0
        print(f"\n BEST: {best_method} ({best_reduction:.1f}% reduction)")

    return results
|
|
|
|
|
|
|
|
def discover_circuits(base: Path = CIRCUITS_PATH) -> List[CircuitSpec]:
    """Find all circuits under *base*.

    A circuit is a direct sub-directory of *base* that contains a
    ``config.json`` file and at least one ``*.safetensors`` file. Missing
    config fields default to the directory name (``name``) or ``0``.

    Discovery is best-effort: a directory whose config cannot be read, parsed
    or turned into a CircuitSpec is skipped.

    Args:
        base: Directory to scan.

    Returns:
        CircuitSpec objects sorted by ``(inputs, neurons)``.
    """
    circuits = []
    for entry in base.iterdir():
        config_path = entry / 'config.json'
        # any() stops at the first match instead of listing every weight file.
        if not (entry.is_dir() and config_path.exists() and any(entry.glob('*.safetensors'))):
            continue
        try:
            with open(config_path, encoding='utf-8') as f:
                meta = json.load(f)
            circuits.append(CircuitSpec(
                name=meta.get('name', entry.name),
                path=entry,
                inputs=meta.get('inputs', 0),
                outputs=meta.get('outputs', 0),
                neurons=meta.get('neurons', 0),
                layers=meta.get('layers', 0),
                parameters=meta.get('parameters', 0)
            ))
        except Exception:
            # Skip unusable entries, but (unlike a bare ``except``) let
            # KeyboardInterrupt and SystemExit propagate.
            continue
    return sorted(circuits, key=lambda spec: (spec.inputs, spec.neurons))
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Modes, in order of precedence:

    - ``--list``: print the discovered circuits and exit.
    - ``--all``: run the enabled methods on every discovered circuit with at
      most ``--max-inputs`` inputs; a failure on one circuit is reported and
      the run continues.
    - ``<circuit>``: run on one circuit directory (the name is also tried with
      a ``threshold-`` prefix); with ``--save`` the lowest-magnitude result
      that meets the fitness threshold is written out.
    - otherwise: print help and usage examples.
    """
    parser = argparse.ArgumentParser(description='Threshold Circuit Pruner v5')
    parser.add_argument('circuit', nargs='?', help='Circuit name')
    parser.add_argument('--weights', type=str, help='Specific .safetensors file')
    parser.add_argument('--list', action='store_true')
    parser.add_argument('--all', action='store_true')
    parser.add_argument('--max-inputs', type=int, default=10)
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--methods', type=str, help='Comma-separated: mag,zero,evo,exh,arch,comp')
    parser.add_argument('--fitness', type=float, default=0.9999)
    parser.add_argument('--quiet', action='store_true')
    parser.add_argument('--save', action='store_true')
    parser.add_argument('--evo-pop', type=int, default=0)
    parser.add_argument('--evo-gens', type=int, default=2000)
    parser.add_argument('--exhaustive-max-params', type=int, default=12)
    parser.add_argument('--target-mag', type=int, default=-1)
    parser.add_argument('--arch-hidden', type=int, default=3)
    parser.add_argument('--arch-max-weight', type=int, default=3)
    parser.add_argument('--arch-max-mag', type=int, default=20)

    args = parser.parse_args()

    if args.list:
        specs = discover_circuits()
        print(f"\nAvailable circuits ({len(specs)}):\n")
        for s in specs:
            print(f" {s.name:<40} {s.inputs}in/{s.outputs}out {s.neurons}N {s.parameters}P")
        return

    vram_cfg = VRAMConfig()
    cfg = Config(
        device=args.device,
        fitness_threshold=args.fitness,
        verbose=not args.quiet,
        vram=vram_cfg,
        evo_pop_size=args.evo_pop,
        evo_generations=args.evo_gens,
        exhaustive_max_params=args.exhaustive_max_params,
        exhaustive_target_mag=args.target_mag,
        arch_hidden_neurons=args.arch_hidden,
        arch_max_weight=args.arch_max_weight,
        arch_max_mag=args.arch_max_mag
    )

    if args.methods:
        # Accepted spellings -> suffix of the Config ``run_<method>`` flag.
        method_map = {
            'mag': 'magnitude', 'magnitude': 'magnitude',
            'zero': 'zero',
            'evo': 'evolutionary', 'evolutionary': 'evolutionary',
            'exh': 'exhaustive_mag', 'exh_mag': 'exhaustive_mag', 'exhaustive': 'exhaustive_mag',
            'arch': 'architecture', 'architecture': 'architecture',
            'comp': 'compositional', 'compositional': 'compositional'
        }
        for m in args.methods.lower().split(','):
            m = m.strip()
            if m in method_map:
                setattr(cfg, f'run_{method_map[m]}', True)
            elif m:
                # A typo used to be dropped silently, leaving the method disabled.
                print(f"Warning: unknown method '{m}' ignored")

    RESULTS_PATH.mkdir(exist_ok=True)

    if args.all:
        specs = [s for s in discover_circuits() if s.inputs <= args.max_inputs]
        print(f"\nRunning on {len(specs)} circuits...")
        for spec in specs:
            try:
                circuit = AdaptiveCircuit(spec.path, cfg.device)
                run_all_methods(circuit, cfg)
                clear_vram()
            except Exception as e:
                # Batch mode: report and move on to the next circuit.
                print(f"ERROR on {spec.name}: {e}")
    elif args.circuit:
        path = CIRCUITS_PATH / args.circuit
        if not path.exists():
            path = CIRCUITS_PATH / f'threshold-{args.circuit}'
        if not path.exists():
            print(f"Circuit not found: {args.circuit}")
            return

        circuit = AdaptiveCircuit(path, cfg.device, args.weights)
        results = run_all_methods(circuit, cfg)

        if args.save and results:
            # Filter by fitness BEFORE minimising magnitude: otherwise a
            # low-magnitude result that fails the threshold blocks saving even
            # though another method produced a valid result.
            passing = [r for r in results.values() if r.fitness >= cfg.fitness_threshold]
            if passing:
                best = min(passing, key=lambda r: r.final_stats.get('magnitude', float('inf')))
                path = circuit.save_weights(best.final_weights, f'pruned_{best.method}')
                print(f"\nSaved to: {path}")
    else:
        parser.print_help()
        print("\n\nExamples:")
        print(" python prune.py --list")
        print(" python prune.py threshold-xor --methods evo")
        print(" python prune.py threshold-xor --methods exh --exhaustive-max-params 20")
        print(" python prune.py threshold-crc16-mag53 --methods comp")
        print(" python prune.py --all --max-inputs 8 --methods mag,zero")
|
|
|
|
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|