CharlesCNorton commited on
Commit ·
1721000
1
Parent(s): 43d198d
Remove Coq export functionality from pruner
Browse files- Remove export_coq config option and coq_output_dir
- Remove export_coq() function
- Remove --export-coq CLI argument
- Update method count in docstring (17 -> 15)
prune.py
CHANGED
|
@@ -2,14 +2,13 @@
|
|
| 2 |
Threshold Circuit Pruner
|
| 3 |
|
| 4 |
Comprehensive pruning framework for threshold logic circuits.
|
| 5 |
-
Supports
|
| 6 |
|
| 7 |
Usage:
|
| 8 |
python prune.py threshold-hamming74decoder
|
| 9 |
python prune.py threshold-hamming74decoder --methods evo,depth,symmetry
|
| 10 |
python prune.py --list
|
| 11 |
python prune.py --all --max-inputs 8
|
| 12 |
-
python prune.py threshold-xor --export-coq
|
| 13 |
"""
|
| 14 |
|
| 15 |
import torch
|
|
@@ -126,24 +125,22 @@ class Config:
|
|
| 126 |
run_quantize: bool = True
|
| 127 |
run_evolutionary: bool = True
|
| 128 |
run_annealing: bool = True
|
| 129 |
-
|
| 130 |
-
run_lottery: bool = True
|
| 131 |
run_topology: bool = True
|
| 132 |
-
run_structured: bool = True
|
| 133 |
run_sensitivity: bool = True
|
| 134 |
run_weight_sharing: bool = True
|
| 135 |
-
run_random: bool = True
|
| 136 |
-
run_pareto: bool = True
|
| 137 |
run_depth: bool = True
|
| 138 |
run_gate_subst: bool = True
|
| 139 |
run_symmetry: bool = True
|
| 140 |
run_fanin: bool = True
|
| 141 |
-
|
|
|
|
| 142 |
|
| 143 |
magnitude_passes: int = 100
|
| 144 |
|
| 145 |
exhaustive_range: int = 2
|
| 146 |
exhaustive_max_params: int = 12
|
|
|
|
| 147 |
|
| 148 |
evo_generations: int = 2000
|
| 149 |
evo_pop_size: int = 0
|
|
@@ -161,10 +158,6 @@ class Config:
|
|
| 161 |
annealing_parallel_chains: int = 0
|
| 162 |
|
| 163 |
quantize_targets: List[float] = field(default_factory=lambda: [-1.0, 0.0, 1.0])
|
| 164 |
-
pareto_levels: List[float] = field(default_factory=lambda: [1.0, 0.99, 0.95, 0.90, 0.80])
|
| 165 |
-
|
| 166 |
-
lottery_rounds: int = 10
|
| 167 |
-
lottery_prune_rate: float = 0.2
|
| 168 |
|
| 169 |
topology_generations: int = 500
|
| 170 |
topology_remove_prob: float = 0.2
|
|
@@ -172,14 +165,9 @@ class Config:
|
|
| 172 |
|
| 173 |
sensitivity_samples: int = 1000
|
| 174 |
|
| 175 |
-
random_iterations: int = 10000
|
| 176 |
-
|
| 177 |
depth_max_collapse: int = 3
|
| 178 |
fanin_target: int = 4
|
| 179 |
|
| 180 |
-
export_coq: bool = False
|
| 181 |
-
coq_output_dir: Path = field(default_factory=lambda: Path('D:/threshold-pruner/coq_exports'))
|
| 182 |
-
|
| 183 |
|
| 184 |
@dataclass
|
| 185 |
class CircuitSpec:
|
|
@@ -260,6 +248,7 @@ class ComputationGraph:
|
|
| 260 |
}
|
| 261 |
self.layer_groups[depth].append(neuron_name)
|
| 262 |
|
|
|
|
| 263 |
self._identify_outputs()
|
| 264 |
|
| 265 |
def _estimate_depth(self, name: str) -> int:
|
|
@@ -285,6 +274,36 @@ class ComputationGraph:
|
|
| 285 |
|
| 286 |
return depth
|
| 287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
def _identify_outputs(self):
|
| 289 |
"""Identify which neurons are outputs based on n_outputs and topology."""
|
| 290 |
|
|
@@ -398,11 +417,17 @@ class ComputationGraph:
|
|
| 398 |
return outputs
|
| 399 |
|
| 400 |
def _get_neuron_input(self, neuron_name: str, activations: Dict, raw_input: torch.Tensor, expected_size: int) -> torch.Tensor:
|
| 401 |
-
"""Determine input for a neuron based on naming conventions."""
|
| 402 |
-
|
|
|
|
|
|
|
| 403 |
return raw_input
|
| 404 |
|
| 405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
|
| 407 |
if 'layer2' in neuron_name:
|
| 408 |
base = neuron_name.replace('.layer2', '')
|
|
@@ -1058,6 +1083,36 @@ class BatchedEvaluator:
|
|
| 1058 |
if original_fitness < 0.999 and self.cfg.verbose:
|
| 1059 |
print(f" [EVAL ERROR] Native eval fitness={original_fitness:.4f}")
|
| 1060 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1061 |
def _setup_vmap(self):
|
| 1062 |
"""Setup vmap-based parallel evaluation."""
|
| 1063 |
try:
|
|
@@ -1123,24 +1178,25 @@ class BatchedEvaluator:
|
|
| 1123 |
|
| 1124 |
input_size = w_shape[-1] if w_shape else 0
|
| 1125 |
|
| 1126 |
-
input_source = 'raw'
|
| 1127 |
-
input_neurons = []
|
| 1128 |
-
|
| 1129 |
-
if
|
| 1130 |
-
|
| 1131 |
-
|
| 1132 |
-
|
| 1133 |
-
|
| 1134 |
-
|
| 1135 |
-
|
| 1136 |
-
|
| 1137 |
-
|
| 1138 |
-
|
| 1139 |
-
|
| 1140 |
-
|
| 1141 |
-
|
| 1142 |
-
|
| 1143 |
-
|
|
|
|
| 1144 |
|
| 1145 |
self.neuron_eval_order.append(neuron_name)
|
| 1146 |
self.neuron_weight_slices[neuron_name] = {
|
|
@@ -1836,194 +1892,6 @@ def prune_annealing(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
|
|
| 1836 |
)
|
| 1837 |
|
| 1838 |
|
| 1839 |
-
def prune_neuron(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 1840 |
-
"""
|
| 1841 |
-
Truly exhaustive neuron-level pruning.
|
| 1842 |
-
|
| 1843 |
-
For each neuron, fix its parameters to zero and exhaustively search
|
| 1844 |
-
all weight configurations for the REMAINING neurons. If any config
|
| 1845 |
-
works, the neuron is provably unnecessary.
|
| 1846 |
-
"""
|
| 1847 |
-
start = time.perf_counter()
|
| 1848 |
-
original = circuit.stats()
|
| 1849 |
-
|
| 1850 |
-
neuron_groups = defaultdict(list)
|
| 1851 |
-
for key in circuit.weights.keys():
|
| 1852 |
-
parts = key.rsplit('.', 1)
|
| 1853 |
-
neuron_name = parts[0] if len(parts) == 2 else key.split('.')[0]
|
| 1854 |
-
neuron_groups[neuron_name].append(key)
|
| 1855 |
-
|
| 1856 |
-
neuron_names = list(neuron_groups.keys())
|
| 1857 |
-
weight_keys = list(circuit.weights.keys())
|
| 1858 |
-
weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
|
| 1859 |
-
weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}
|
| 1860 |
-
n_total_params = sum(weight_sizes.values())
|
| 1861 |
-
|
| 1862 |
-
if cfg.verbose:
|
| 1863 |
-
print(f" Found {len(neuron_groups)} neuron groups")
|
| 1864 |
-
print(f" Total parameters: {n_total_params}")
|
| 1865 |
-
|
| 1866 |
-
search_range = cfg.exhaustive_range
|
| 1867 |
-
values = list(range(-search_range, search_range + 1))
|
| 1868 |
-
|
| 1869 |
-
def vector_to_weights(vec):
|
| 1870 |
-
weights = {}
|
| 1871 |
-
idx = 0
|
| 1872 |
-
for k in weight_keys:
|
| 1873 |
-
size = weight_sizes[k]
|
| 1874 |
-
weights[k] = torch.tensor(vec[idx:idx+size], dtype=torch.float32, device=cfg.device).view(weight_shapes[k])
|
| 1875 |
-
idx += size
|
| 1876 |
-
return weights
|
| 1877 |
-
|
| 1878 |
-
def get_neuron_param_indices(neuron_name):
|
| 1879 |
-
indices = []
|
| 1880 |
-
idx = 0
|
| 1881 |
-
for k in weight_keys:
|
| 1882 |
-
size = weight_sizes[k]
|
| 1883 |
-
if k in neuron_groups[neuron_name]:
|
| 1884 |
-
indices.extend(range(idx, idx + size))
|
| 1885 |
-
idx += size
|
| 1886 |
-
return indices
|
| 1887 |
-
|
| 1888 |
-
best_weights = circuit.clone_weights()
|
| 1889 |
-
best_neurons_removed = 0
|
| 1890 |
-
removed_neuron_names = []
|
| 1891 |
-
|
| 1892 |
-
for neuron_name in neuron_names:
|
| 1893 |
-
neuron_indices = set(get_neuron_param_indices(neuron_name))
|
| 1894 |
-
other_indices = [i for i in range(n_total_params) if i not in neuron_indices]
|
| 1895 |
-
n_other_params = len(other_indices)
|
| 1896 |
-
|
| 1897 |
-
if n_other_params > cfg.exhaustive_max_params:
|
| 1898 |
-
if cfg.verbose:
|
| 1899 |
-
print(f" [{neuron_name}] Skipping: {n_other_params} remaining params > max {cfg.exhaustive_max_params}")
|
| 1900 |
-
continue
|
| 1901 |
-
|
| 1902 |
-
search_space = (2 * search_range + 1) ** n_other_params
|
| 1903 |
-
|
| 1904 |
-
if cfg.verbose:
|
| 1905 |
-
print(f" [{neuron_name}] Testing removal: searching {search_space:,} configs for {n_other_params} remaining params")
|
| 1906 |
-
|
| 1907 |
-
found_valid = False
|
| 1908 |
-
best_config = None
|
| 1909 |
-
best_mag = float('inf')
|
| 1910 |
-
tested = 0
|
| 1911 |
-
report_interval = max(1, search_space // 10)
|
| 1912 |
-
|
| 1913 |
-
for combo in product(values, repeat=n_other_params):
|
| 1914 |
-
tested += 1
|
| 1915 |
-
|
| 1916 |
-
full_vec = [0] * n_total_params
|
| 1917 |
-
for i, val in zip(other_indices, combo):
|
| 1918 |
-
full_vec[i] = val
|
| 1919 |
-
|
| 1920 |
-
weights = vector_to_weights(full_vec)
|
| 1921 |
-
fitness = evaluator.evaluate_single(weights)
|
| 1922 |
-
|
| 1923 |
-
if fitness >= cfg.fitness_threshold:
|
| 1924 |
-
mag = sum(abs(v) for v in combo)
|
| 1925 |
-
if not found_valid or mag < best_mag:
|
| 1926 |
-
found_valid = True
|
| 1927 |
-
best_mag = mag
|
| 1928 |
-
best_config = full_vec
|
| 1929 |
-
|
| 1930 |
-
if cfg.verbose and tested % report_interval == 0:
|
| 1931 |
-
elapsed = time.perf_counter() - start
|
| 1932 |
-
pct = 100 * tested / search_space
|
| 1933 |
-
print(f" [{elapsed:6.1f}s] {pct:5.1f}% | valid={'YES' if found_valid else 'no '} | best_mag={best_mag if found_valid else '-'}")
|
| 1934 |
-
|
| 1935 |
-
if found_valid:
|
| 1936 |
-
if cfg.verbose:
|
| 1937 |
-
print(f" [REMOVABLE] {neuron_name} can be removed! Best remaining mag={best_mag}")
|
| 1938 |
-
removed_neuron_names.append(neuron_name)
|
| 1939 |
-
best_neurons_removed += 1
|
| 1940 |
-
best_weights = vector_to_weights(best_config)
|
| 1941 |
-
else:
|
| 1942 |
-
if cfg.verbose:
|
| 1943 |
-
print(f" [REQUIRED] {neuron_name} is necessary")
|
| 1944 |
-
|
| 1945 |
-
if cfg.verbose:
|
| 1946 |
-
elapsed = time.perf_counter() - start
|
| 1947 |
-
print(f" [NEURON COMPLETE] {best_neurons_removed} neurons removable out of {len(neuron_names)}")
|
| 1948 |
-
print(f" Removable: {removed_neuron_names if removed_neuron_names else 'none'}")
|
| 1949 |
-
print(f" Time: {elapsed:.1f}s")
|
| 1950 |
-
|
| 1951 |
-
return PruneResult(
|
| 1952 |
-
method='neuron',
|
| 1953 |
-
original_stats=original,
|
| 1954 |
-
final_stats=circuit.stats(best_weights),
|
| 1955 |
-
final_weights=best_weights,
|
| 1956 |
-
fitness=evaluator.evaluate_single(best_weights),
|
| 1957 |
-
time_seconds=time.perf_counter() - start,
|
| 1958 |
-
metadata={'neurons_removed': best_neurons_removed, 'removed_names': removed_neuron_names}
|
| 1959 |
-
)
|
| 1960 |
-
|
| 1961 |
-
|
| 1962 |
-
def prune_lottery(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 1963 |
-
"""Lottery Ticket pruning."""
|
| 1964 |
-
start = time.perf_counter()
|
| 1965 |
-
original = circuit.stats()
|
| 1966 |
-
|
| 1967 |
-
weights = circuit.clone_weights()
|
| 1968 |
-
initial = circuit.clone_weights()
|
| 1969 |
-
history = []
|
| 1970 |
-
|
| 1971 |
-
if cfg.verbose:
|
| 1972 |
-
print(f" Lottery: {cfg.lottery_rounds} rounds, {cfg.lottery_prune_rate * 100:.0f}% per round")
|
| 1973 |
-
|
| 1974 |
-
mask = {k: torch.ones_like(v) for k, v in weights.items()}
|
| 1975 |
-
|
| 1976 |
-
for rnd in range(cfg.lottery_rounds):
|
| 1977 |
-
all_weights = []
|
| 1978 |
-
for name, tensor in weights.items():
|
| 1979 |
-
flat = tensor.flatten()
|
| 1980 |
-
m_flat = mask[name].flatten()
|
| 1981 |
-
for i in range(len(flat)):
|
| 1982 |
-
if m_flat[i] > 0 and flat[i].item() != 0:
|
| 1983 |
-
all_weights.append((abs(flat[i].item()), name, i))
|
| 1984 |
-
|
| 1985 |
-
if not all_weights:
|
| 1986 |
-
break
|
| 1987 |
-
|
| 1988 |
-
all_weights.sort(key=lambda x: x[0])
|
| 1989 |
-
n_prune = max(1, int(len(all_weights) * cfg.lottery_prune_rate))
|
| 1990 |
-
to_prune = all_weights[:n_prune]
|
| 1991 |
-
|
| 1992 |
-
for _, name, idx in to_prune:
|
| 1993 |
-
m_flat = mask[name].flatten()
|
| 1994 |
-
m_flat[idx] = 0
|
| 1995 |
-
mask[name] = m_flat.view(mask[name].shape)
|
| 1996 |
-
|
| 1997 |
-
for name in weights:
|
| 1998 |
-
weights[name] = initial[name] * mask[name]
|
| 1999 |
-
|
| 2000 |
-
fitness = evaluator.evaluate_single(weights)
|
| 2001 |
-
stats = circuit.stats(weights)
|
| 2002 |
-
history.append({'round': rnd, 'pruned': n_prune, 'fitness': fitness, 'magnitude': stats['magnitude']})
|
| 2003 |
-
|
| 2004 |
-
if cfg.verbose:
|
| 2005 |
-
print(f" Round {rnd}: pruned {n_prune}, fitness={fitness:.4f}")
|
| 2006 |
-
|
| 2007 |
-
if fitness < cfg.fitness_threshold:
|
| 2008 |
-
for _, name, idx in to_prune:
|
| 2009 |
-
m_flat = mask[name].flatten()
|
| 2010 |
-
m_flat[idx] = 1
|
| 2011 |
-
mask[name] = m_flat.view(mask[name].shape)
|
| 2012 |
-
for name in weights:
|
| 2013 |
-
weights[name] = initial[name] * mask[name]
|
| 2014 |
-
break
|
| 2015 |
-
|
| 2016 |
-
return PruneResult(
|
| 2017 |
-
method='lottery',
|
| 2018 |
-
original_stats=original,
|
| 2019 |
-
final_stats=circuit.stats(weights),
|
| 2020 |
-
final_weights=weights,
|
| 2021 |
-
fitness=evaluator.evaluate_single(weights),
|
| 2022 |
-
time_seconds=time.perf_counter() - start,
|
| 2023 |
-
history=history
|
| 2024 |
-
)
|
| 2025 |
-
|
| 2026 |
-
|
| 2027 |
def prune_topology(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 2028 |
"""Topology search - remove connection groups."""
|
| 2029 |
start = time.perf_counter()
|
|
@@ -2092,14 +1960,14 @@ def prune_topology(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: C
|
|
| 2092 |
)
|
| 2093 |
|
| 2094 |
|
| 2095 |
-
def
|
| 2096 |
-
"""
|
| 2097 |
start = time.perf_counter()
|
| 2098 |
weights = circuit.clone_weights()
|
| 2099 |
original = circuit.stats(weights)
|
| 2100 |
|
| 2101 |
if cfg.verbose:
|
| 2102 |
-
print(f"
|
| 2103 |
|
| 2104 |
removed = 0
|
| 2105 |
|
|
@@ -2137,7 +2005,7 @@ def prune_structured(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
|
|
| 2137 |
print(f" Removed {removed} rows/columns")
|
| 2138 |
|
| 2139 |
return PruneResult(
|
| 2140 |
-
method='
|
| 2141 |
original_stats=original,
|
| 2142 |
final_stats=circuit.stats(weights),
|
| 2143 |
final_weights=weights,
|
|
@@ -2279,99 +2147,6 @@ def prune_weight_sharing(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator,
|
|
| 2279 |
)
|
| 2280 |
|
| 2281 |
|
| 2282 |
-
def prune_random(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 2283 |
-
"""Random search baseline."""
|
| 2284 |
-
start = time.perf_counter()
|
| 2285 |
-
original = circuit.stats()
|
| 2286 |
-
|
| 2287 |
-
if cfg.verbose:
|
| 2288 |
-
print(f" Random search ({cfg.random_iterations:,} iterations)...")
|
| 2289 |
-
|
| 2290 |
-
base_vector = circuit.weights_to_vector(circuit.weights)
|
| 2291 |
-
best_weights = circuit.clone_weights()
|
| 2292 |
-
best_mag = sum(t.abs().sum().item() for t in best_weights.values())
|
| 2293 |
-
best_fitness = evaluator.evaluate_single(best_weights)
|
| 2294 |
-
|
| 2295 |
-
n_valid = 0
|
| 2296 |
-
|
| 2297 |
-
batch_size = min(10000, evaluator.max_batch)
|
| 2298 |
-
n_batches = cfg.random_iterations // batch_size
|
| 2299 |
-
|
| 2300 |
-
for batch in range(n_batches):
|
| 2301 |
-
population = base_vector.unsqueeze(0).expand(batch_size, -1).clone()
|
| 2302 |
-
noise = torch.randn_like(population) * 2
|
| 2303 |
-
population = (population + noise).round()
|
| 2304 |
-
|
| 2305 |
-
fitness = evaluator.evaluate_population(population)
|
| 2306 |
-
|
| 2307 |
-
valid_mask = fitness >= cfg.fitness_threshold
|
| 2308 |
-
n_valid += valid_mask.sum().item()
|
| 2309 |
-
|
| 2310 |
-
if valid_mask.any():
|
| 2311 |
-
magnitudes = population.abs().sum(dim=1)
|
| 2312 |
-
magnitudes[~valid_mask] = float('inf')
|
| 2313 |
-
best_idx = magnitudes.argmin().item()
|
| 2314 |
-
|
| 2315 |
-
if magnitudes[best_idx] < best_mag:
|
| 2316 |
-
best_mag = magnitudes[best_idx].item()
|
| 2317 |
-
best_weights = circuit.vector_to_weights(population[best_idx])
|
| 2318 |
-
best_fitness = fitness[best_idx].item()
|
| 2319 |
-
|
| 2320 |
-
if cfg.verbose and batch % 10 == 0:
|
| 2321 |
-
print(f" Batch {batch}/{n_batches}: valid={n_valid}, best_mag={best_mag:.0f}")
|
| 2322 |
-
|
| 2323 |
-
return PruneResult(
|
| 2324 |
-
method='random',
|
| 2325 |
-
original_stats=original,
|
| 2326 |
-
final_stats=circuit.stats(best_weights),
|
| 2327 |
-
final_weights=best_weights,
|
| 2328 |
-
fitness=best_fitness,
|
| 2329 |
-
time_seconds=time.perf_counter() - start,
|
| 2330 |
-
metadata={'total_valid': n_valid, 'total_tested': cfg.random_iterations}
|
| 2331 |
-
)
|
| 2332 |
-
|
| 2333 |
-
|
| 2334 |
-
def prune_pareto(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 2335 |
-
"""Explore Pareto frontier of correctness vs. size."""
|
| 2336 |
-
start = time.perf_counter()
|
| 2337 |
-
original = circuit.stats()
|
| 2338 |
-
frontier = []
|
| 2339 |
-
|
| 2340 |
-
if cfg.verbose:
|
| 2341 |
-
print(f" Exploring Pareto frontier...")
|
| 2342 |
-
|
| 2343 |
-
for target in cfg.pareto_levels:
|
| 2344 |
-
relaxed_cfg = Config(
|
| 2345 |
-
device=cfg.device,
|
| 2346 |
-
fitness_threshold=target,
|
| 2347 |
-
magnitude_passes=30,
|
| 2348 |
-
verbose=False,
|
| 2349 |
-
vram=cfg.vram
|
| 2350 |
-
)
|
| 2351 |
-
|
| 2352 |
-
result = prune_magnitude(circuit, evaluator, relaxed_cfg)
|
| 2353 |
-
|
| 2354 |
-
frontier.append({
|
| 2355 |
-
'target': target,
|
| 2356 |
-
'actual': result.fitness,
|
| 2357 |
-
'magnitude': result.final_stats['magnitude'],
|
| 2358 |
-
'nonzero': result.final_stats['nonzero']
|
| 2359 |
-
})
|
| 2360 |
-
|
| 2361 |
-
if cfg.verbose:
|
| 2362 |
-
print(f" Target {target:.2f}: fitness={result.fitness:.4f}, mag={result.final_stats['magnitude']:.0f}")
|
| 2363 |
-
|
| 2364 |
-
return PruneResult(
|
| 2365 |
-
method='pareto',
|
| 2366 |
-
original_stats=original,
|
| 2367 |
-
final_stats=frontier[-1] if frontier else original,
|
| 2368 |
-
final_weights=circuit.clone_weights(),
|
| 2369 |
-
fitness=frontier[0]['actual'] if frontier else 1.0,
|
| 2370 |
-
time_seconds=time.perf_counter() - start,
|
| 2371 |
-
history=frontier
|
| 2372 |
-
)
|
| 2373 |
-
|
| 2374 |
-
|
| 2375 |
def prune_depth(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 2376 |
"""
|
| 2377 |
Depth reduction - attempt to collapse consecutive layers.
|
|
@@ -2790,9 +2565,9 @@ def _configs_at_magnitude(mag: int, n_params: int):
|
|
| 2790 |
yield tuple(signed)
|
| 2791 |
|
| 2792 |
|
| 2793 |
-
def
|
| 2794 |
"""
|
| 2795 |
-
Exhaustive search by magnitude level - finds provably
|
| 2796 |
|
| 2797 |
Searches magnitude 0, then 1, then 2, ... until valid solutions found.
|
| 2798 |
Returns ALL valid solutions at the minimum magnitude (to discover families).
|
|
@@ -2924,7 +2699,7 @@ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
|
|
| 2924 |
print(f" - Time: {elapsed:.1f}s")
|
| 2925 |
|
| 2926 |
return PruneResult(
|
| 2927 |
-
method='
|
| 2928 |
original_stats=original,
|
| 2929 |
final_stats=circuit.stats(best_weights),
|
| 2930 |
final_weights=best_weights,
|
|
@@ -2939,71 +2714,194 @@ def prune_exhaustive(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg:
|
|
| 2939 |
)
|
| 2940 |
|
| 2941 |
|
| 2942 |
-
def
|
| 2943 |
"""
|
| 2944 |
-
|
| 2945 |
-
|
| 2946 |
"""
|
| 2947 |
-
|
| 2948 |
-
|
| 2949 |
-
|
| 2950 |
-
|
| 2951 |
-
|
| 2952 |
-
|
| 2953 |
-
|
| 2954 |
-
|
| 2955 |
-
|
| 2956 |
-
|
| 2957 |
-
|
| 2958 |
-
|
| 2959 |
-
|
| 2960 |
-
|
| 2961 |
-
|
| 2962 |
-
|
| 2963 |
-
|
| 2964 |
-
|
| 2965 |
-
|
| 2966 |
-
|
| 2967 |
-
|
| 2968 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2969 |
|
| 2970 |
-
|
| 2971 |
-
|
| 2972 |
-
info = circuit.graph.neurons[neuron_name]
|
| 2973 |
-
w_key = info.get('weight_key')
|
| 2974 |
-
b_key = info.get('bias_key')
|
| 2975 |
|
| 2976 |
-
|
| 2977 |
-
|
| 2978 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2979 |
|
| 2980 |
-
|
| 2981 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2982 |
|
| 2983 |
-
|
| 2984 |
-
|
| 2985 |
-
|
|
|
|
| 2986 |
|
| 2987 |
-
|
| 2988 |
|
| 2989 |
-
|
| 2990 |
-
|
| 2991 |
-
|
| 2992 |
-
|
| 2993 |
-
|
| 2994 |
-
|
| 2995 |
-
|
| 2996 |
-
|
| 2997 |
-
'Proof.',
|
| 2998 |
-
' intros.',
|
| 2999 |
-
' (* Proof to be completed *)',
|
| 3000 |
-
'Admitted.',
|
| 3001 |
-
])
|
| 3002 |
|
| 3003 |
-
|
| 3004 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3005 |
|
| 3006 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3007 |
|
| 3008 |
|
| 3009 |
def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneResult]:
|
|
@@ -3040,21 +2938,18 @@ def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneRes
|
|
| 3040 |
('magnitude', cfg.run_magnitude, lambda: prune_magnitude(circuit, evaluator, cfg)),
|
| 3041 |
('zero', cfg.run_zero, lambda: prune_zero(circuit, evaluator, cfg)),
|
| 3042 |
('quantize', cfg.run_quantize, lambda: prune_quantize(circuit, evaluator, cfg)),
|
| 3043 |
-
('
|
| 3044 |
-
('lottery', cfg.run_lottery, lambda: prune_lottery(circuit, evaluator, cfg)),
|
| 3045 |
('topology', cfg.run_topology, lambda: prune_topology(circuit, evaluator, cfg)),
|
| 3046 |
-
('structured', cfg.run_structured, lambda: prune_structured(circuit, evaluator, cfg)),
|
| 3047 |
('sensitivity', cfg.run_sensitivity, lambda: prune_sensitivity(circuit, evaluator, cfg)),
|
| 3048 |
('weight_sharing', cfg.run_weight_sharing, lambda: prune_weight_sharing(circuit, evaluator, cfg)),
|
| 3049 |
('depth', cfg.run_depth, lambda: prune_depth(circuit, evaluator, cfg)),
|
| 3050 |
('gate_subst', cfg.run_gate_subst, lambda: prune_gate_substitution(circuit, evaluator, cfg)),
|
| 3051 |
('symmetry', cfg.run_symmetry, lambda: prune_symmetry(circuit, evaluator, cfg)),
|
| 3052 |
('fanin', cfg.run_fanin, lambda: prune_fanin(circuit, evaluator, cfg)),
|
| 3053 |
-
('
|
| 3054 |
-
('
|
| 3055 |
('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
|
| 3056 |
('annealing', cfg.run_annealing, lambda: prune_annealing(circuit, evaluator, cfg)),
|
| 3057 |
-
('pareto', cfg.run_pareto, lambda: prune_pareto(circuit, evaluator, cfg)),
|
| 3058 |
]
|
| 3059 |
|
| 3060 |
enabled_methods = [(name, fn) for name, enabled, fn in methods if enabled]
|
|
@@ -3100,10 +2995,6 @@ def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneRes
|
|
| 3100 |
reduction = 1 - best_mag / original['magnitude']
|
| 3101 |
print(f"\n BEST: {best_method} ({reduction * 100:.1f}% magnitude reduction)")
|
| 3102 |
|
| 3103 |
-
if cfg.export_coq:
|
| 3104 |
-
coq_path = export_coq(circuit, results[best_method].final_weights, cfg)
|
| 3105 |
-
print(f" Coq export: {coq_path}")
|
| 3106 |
-
|
| 3107 |
return results
|
| 3108 |
|
| 3109 |
|
|
@@ -3228,8 +3119,9 @@ def main():
|
|
| 3228 |
parser.add_argument('--sa-iters', type=int, default=50000, help='Simulated annealing iterations')
|
| 3229 |
parser.add_argument('--sa-chains', type=int, default=0, help='Parallel SA chains (0=auto)')
|
| 3230 |
parser.add_argument('--vram-target', type=float, default=0.75)
|
| 3231 |
-
parser.add_argument('--export-coq', action='store_true')
|
| 3232 |
parser.add_argument('--fanin-target', type=int, default=4)
|
|
|
|
|
|
|
| 3233 |
|
| 3234 |
args = parser.parse_args()
|
| 3235 |
|
|
@@ -3251,15 +3143,16 @@ def main():
|
|
| 3251 |
evo_generations=args.evo_gens,
|
| 3252 |
annealing_iterations=args.sa_iters,
|
| 3253 |
annealing_parallel_chains=args.sa_chains,
|
| 3254 |
-
|
| 3255 |
-
|
|
|
|
| 3256 |
)
|
| 3257 |
|
| 3258 |
if args.methods:
|
| 3259 |
all_methods = ['magnitude', 'zero', 'quantize', 'evolutionary', 'annealing',
|
| 3260 |
-
'
|
| 3261 |
-
'
|
| 3262 |
-
'
|
| 3263 |
for m in all_methods:
|
| 3264 |
setattr(cfg, f'run_{m}', False)
|
| 3265 |
|
|
@@ -3271,19 +3164,16 @@ def main():
|
|
| 3271 |
'quant': 'quantize', 'quantize': 'quantize',
|
| 3272 |
'evo': 'evolutionary', 'evolutionary': 'evolutionary',
|
| 3273 |
'anneal': 'annealing', 'annealing': 'annealing', 'sa': 'annealing',
|
| 3274 |
-
'
|
| 3275 |
-
'lottery': 'lottery',
|
| 3276 |
'topo': 'topology', 'topology': 'topology',
|
| 3277 |
-
'struct': 'structured', 'structured': 'structured',
|
| 3278 |
'sens': 'sensitivity', 'sensitivity': 'sensitivity',
|
| 3279 |
'share': 'weight_sharing', 'weight_sharing': 'weight_sharing', 'sharing': 'weight_sharing',
|
| 3280 |
-
'rand': 'random', 'random': 'random',
|
| 3281 |
-
'pareto': 'pareto',
|
| 3282 |
'depth': 'depth',
|
| 3283 |
'gate': 'gate_subst', 'gate_subst': 'gate_subst', 'subst': 'gate_subst',
|
| 3284 |
'sym': 'symmetry', 'symmetry': 'symmetry',
|
| 3285 |
'fanin': 'fanin', 'fan': 'fanin',
|
| 3286 |
-
'
|
|
|
|
| 3287 |
}
|
| 3288 |
if m in method_map:
|
| 3289 |
setattr(cfg, f'run_{method_map[m]}', True)
|
|
@@ -3343,6 +3233,11 @@ def main():
|
|
| 3343 |
print(" python prune.py threshold-hamming74decoder --methods evo")
|
| 3344 |
print(" python prune.py threshold-xor --methods evo --evo-pop 500000 --evo-gens 5000")
|
| 3345 |
print("")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3346 |
print(" # Pipeline mode (chained, each stage feeds into next):")
|
| 3347 |
print(" python prune.py threshold-hamming74decoder --pipeline evo,mag,zero,quant --save")
|
| 3348 |
print(" python prune.py threshold-xor --pipeline anneal,mag,zero --sa-iters 100000")
|
|
|
|
| 2 |
Threshold Circuit Pruner
|
| 3 |
|
| 4 |
Comprehensive pruning framework for threshold logic circuits.
|
| 5 |
+
Supports 15 pruning methods with GPU-optimized parallel evaluation.
|
| 6 |
|
| 7 |
Usage:
|
| 8 |
python prune.py threshold-hamming74decoder
|
| 9 |
python prune.py threshold-hamming74decoder --methods evo,depth,symmetry
|
| 10 |
python prune.py --list
|
| 11 |
python prune.py --all --max-inputs 8
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import torch
|
|
|
|
| 125 |
run_quantize: bool = True
|
| 126 |
run_evolutionary: bool = True
|
| 127 |
run_annealing: bool = True
|
| 128 |
+
run_structural: bool = True
|
|
|
|
| 129 |
run_topology: bool = True
|
|
|
|
| 130 |
run_sensitivity: bool = True
|
| 131 |
run_weight_sharing: bool = True
|
|
|
|
|
|
|
| 132 |
run_depth: bool = True
|
| 133 |
run_gate_subst: bool = True
|
| 134 |
run_symmetry: bool = True
|
| 135 |
run_fanin: bool = True
|
| 136 |
+
run_exhaustive_mag: bool = True
|
| 137 |
+
run_exhaustive_sparse: bool = True
|
| 138 |
|
| 139 |
magnitude_passes: int = 100
|
| 140 |
|
| 141 |
exhaustive_range: int = 2
|
| 142 |
exhaustive_max_params: int = 12
|
| 143 |
+
sparse_max_weight: int = 3
|
| 144 |
|
| 145 |
evo_generations: int = 2000
|
| 146 |
evo_pop_size: int = 0
|
|
|
|
| 158 |
annealing_parallel_chains: int = 0
|
| 159 |
|
| 160 |
quantize_targets: List[float] = field(default_factory=lambda: [-1.0, 0.0, 1.0])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
topology_generations: int = 500
|
| 163 |
topology_remove_prob: float = 0.2
|
|
|
|
| 165 |
|
| 166 |
sensitivity_samples: int = 1000
|
| 167 |
|
|
|
|
|
|
|
| 168 |
depth_max_collapse: int = 3
|
| 169 |
fanin_target: int = 4
|
| 170 |
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
@dataclass
|
| 173 |
class CircuitSpec:
|
|
|
|
| 248 |
}
|
| 249 |
self.layer_groups[depth].append(neuron_name)
|
| 250 |
|
| 251 |
+
self._infer_depth_from_shapes()
|
| 252 |
self._identify_outputs()
|
| 253 |
|
| 254 |
def _estimate_depth(self, name: str) -> int:
|
|
|
|
| 274 |
|
| 275 |
return depth
|
| 276 |
|
| 277 |
+
def _infer_depth_from_shapes(self):
|
| 278 |
+
"""Second pass: infer depth from weight shapes when naming is ambiguous."""
|
| 279 |
+
neurons_at_depth_0 = []
|
| 280 |
+
neurons_needing_inference = []
|
| 281 |
+
|
| 282 |
+
for name, info in self.neurons.items():
|
| 283 |
+
input_size = info.get('input_size', 0)
|
| 284 |
+
if input_size == self.n_inputs:
|
| 285 |
+
info['depth'] = 0
|
| 286 |
+
info['input_source'] = 'raw'
|
| 287 |
+
info['input_neurons'] = []
|
| 288 |
+
neurons_at_depth_0.append(name)
|
| 289 |
+
elif input_size > 0 and input_size != self.n_inputs:
|
| 290 |
+
neurons_needing_inference.append((name, input_size))
|
| 291 |
+
|
| 292 |
+
for name, input_size in neurons_needing_inference:
|
| 293 |
+
if input_size == len(neurons_at_depth_0):
|
| 294 |
+
self.neurons[name]['depth'] = 1
|
| 295 |
+
self.neurons[name]['input_source'] = 'neurons'
|
| 296 |
+
self.neurons[name]['input_neurons'] = sorted(neurons_at_depth_0)
|
| 297 |
+
elif input_size < len(neurons_at_depth_0) and input_size > 0:
|
| 298 |
+
candidates = sorted(neurons_at_depth_0)[:input_size]
|
| 299 |
+
self.neurons[name]['depth'] = 1
|
| 300 |
+
self.neurons[name]['input_source'] = 'neurons'
|
| 301 |
+
self.neurons[name]['input_neurons'] = candidates
|
| 302 |
+
|
| 303 |
+
self.layer_groups = defaultdict(list)
|
| 304 |
+
for name, info in self.neurons.items():
|
| 305 |
+
self.layer_groups[info['depth']].append(name)
|
| 306 |
+
|
| 307 |
def _identify_outputs(self):
|
| 308 |
"""Identify which neurons are outputs based on n_outputs and topology."""
|
| 309 |
|
|
|
|
| 417 |
return outputs
|
| 418 |
|
| 419 |
def _get_neuron_input(self, neuron_name: str, activations: Dict, raw_input: torch.Tensor, expected_size: int) -> torch.Tensor:
|
| 420 |
+
"""Determine input for a neuron based on topology inference or naming conventions."""
|
| 421 |
+
info = self.neurons.get(neuron_name, {})
|
| 422 |
+
|
| 423 |
+
if info.get('input_source') == 'raw' or expected_size == self.n_inputs or expected_size == raw_input.shape[-1]:
|
| 424 |
return raw_input
|
| 425 |
|
| 426 |
+
if info.get('input_source') == 'neurons' and info.get('input_neurons'):
|
| 427 |
+
input_neurons = info['input_neurons']
|
| 428 |
+
vals = [activations[n] for n in input_neurons if n in activations]
|
| 429 |
+
if len(vals) == len(input_neurons):
|
| 430 |
+
return torch.stack(vals, dim=-1)
|
| 431 |
|
| 432 |
if 'layer2' in neuron_name:
|
| 433 |
base = neuron_name.replace('.layer2', '')
|
|
|
|
| 1083 |
if original_fitness < 0.999 and self.cfg.verbose:
|
| 1084 |
print(f" [EVAL ERROR] Native eval fitness={original_fitness:.4f}")
|
| 1085 |
|
| 1086 |
+
if self.batched_ready and not self.use_native_eval:
|
| 1087 |
+
try:
|
| 1088 |
+
test_vecs = []
|
| 1089 |
+
base = self.circuit.base_vector.clone()
|
| 1090 |
+
test_vecs.append(base)
|
| 1091 |
+
for _ in range(3):
|
| 1092 |
+
perturbed = base.clone()
|
| 1093 |
+
mask = torch.rand_like(perturbed) < 0.3
|
| 1094 |
+
perturbed[mask] = torch.randint(-2, 3, (mask.sum().item(),), device=self.device, dtype=torch.float32)
|
| 1095 |
+
test_vecs.append(perturbed)
|
| 1096 |
+
|
| 1097 |
+
test_pop = torch.stack(test_vecs)
|
| 1098 |
+
single_results = torch.tensor([
|
| 1099 |
+
self.evaluate_single(self.circuit.vector_to_weights(v)) for v in test_vecs
|
| 1100 |
+
], device=self.device)
|
| 1101 |
+
|
| 1102 |
+
batched_results = self._evaluate_batched(test_pop)
|
| 1103 |
+
|
| 1104 |
+
if not torch.allclose(single_results, batched_results, atol=0.01):
|
| 1105 |
+
if self.cfg.verbose:
|
| 1106 |
+
print(f" [EVAL WARNING] Batched eval mismatch (single={single_results.tolist()}, batched={batched_results.tolist()})")
|
| 1107 |
+
print(f" [EVAL WARNING] Falling back to native eval")
|
| 1108 |
+
self.use_native_eval = self.circuit.has_native
|
| 1109 |
+
self.batched_ready = False
|
| 1110 |
+
except Exception as e:
|
| 1111 |
+
if self.cfg.verbose:
|
| 1112 |
+
print(f" [EVAL WARNING] Batched eval failed ({e}), falling back to native eval")
|
| 1113 |
+
self.use_native_eval = self.circuit.has_native
|
| 1114 |
+
self.batched_ready = False
|
| 1115 |
+
|
| 1116 |
def _setup_vmap(self):
|
| 1117 |
"""Setup vmap-based parallel evaluation."""
|
| 1118 |
try:
|
|
|
|
| 1178 |
|
| 1179 |
input_size = w_shape[-1] if w_shape else 0
|
| 1180 |
|
| 1181 |
+
input_source = info.get('input_source', 'raw')
|
| 1182 |
+
input_neurons = info.get('input_neurons', [])
|
| 1183 |
+
|
| 1184 |
+
if input_source == 'raw' and input_size != self.n_inputs:
|
| 1185 |
+
if input_size == self.n_inputs:
|
| 1186 |
+
input_source = 'raw'
|
| 1187 |
+
elif 'layer2' in neuron_name:
|
| 1188 |
+
base = neuron_name.replace('.layer2', '')
|
| 1189 |
+
or_key = f'{base}.layer1.or'
|
| 1190 |
+
nand_key = f'{base}.layer1.nand'
|
| 1191 |
+
if or_key in graph.neurons and nand_key in graph.neurons:
|
| 1192 |
+
input_source = 'neurons'
|
| 1193 |
+
input_neurons = [or_key, nand_key]
|
| 1194 |
+
elif 'xor_final' in neuron_name:
|
| 1195 |
+
prefix = neuron_name.split('.xor_final')[0]
|
| 1196 |
+
candidates = [n for n in graph.neurons if n.startswith(prefix) and 'xor_' in n and 'final' not in n and 'layer2' in n]
|
| 1197 |
+
if len(candidates) >= 2:
|
| 1198 |
+
input_source = 'neurons'
|
| 1199 |
+
input_neurons = sorted(candidates)[-2:]
|
| 1200 |
|
| 1201 |
self.neuron_eval_order.append(neuron_name)
|
| 1202 |
self.neuron_weight_slices[neuron_name] = {
|
|
|
|
| 1892 |
)
|
| 1893 |
|
| 1894 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1895 |
def prune_topology(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 1896 |
"""Topology search - remove connection groups."""
|
| 1897 |
start = time.perf_counter()
|
|
|
|
| 1960 |
)
|
| 1961 |
|
| 1962 |
|
| 1963 |
+
def prune_structural(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 1964 |
+
"""Structural pruning - remove entire rows/columns of weight matrices."""
|
| 1965 |
start = time.perf_counter()
|
| 1966 |
weights = circuit.clone_weights()
|
| 1967 |
original = circuit.stats(weights)
|
| 1968 |
|
| 1969 |
if cfg.verbose:
|
| 1970 |
+
print(f" Structural pruning (rows/columns)...")
|
| 1971 |
|
| 1972 |
removed = 0
|
| 1973 |
|
|
|
|
| 2005 |
print(f" Removed {removed} rows/columns")
|
| 2006 |
|
| 2007 |
return PruneResult(
|
| 2008 |
+
method='structural',
|
| 2009 |
original_stats=original,
|
| 2010 |
final_stats=circuit.stats(weights),
|
| 2011 |
final_weights=weights,
|
|
|
|
| 2147 |
)
|
| 2148 |
|
| 2149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2150 |
def prune_depth(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 2151 |
"""
|
| 2152 |
Depth reduction - attempt to collapse consecutive layers.
|
|
|
|
| 2565 |
yield tuple(signed)
|
| 2566 |
|
| 2567 |
|
| 2568 |
+
def prune_exhaustive_mag(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
|
| 2569 |
"""
|
| 2570 |
+
Exhaustive search by magnitude level - finds provably minimum-magnitude solutions.
|
| 2571 |
|
| 2572 |
Searches magnitude 0, then 1, then 2, ... until valid solutions found.
|
| 2573 |
Returns ALL valid solutions at the minimum magnitude (to discover families).
|
|
|
|
| 2699 |
print(f" - Time: {elapsed:.1f}s")
|
| 2700 |
|
| 2701 |
return PruneResult(
|
| 2702 |
+
method='exhaustive_mag',
|
| 2703 |
original_stats=original,
|
| 2704 |
final_stats=circuit.stats(best_weights),
|
| 2705 |
final_weights=best_weights,
|
|
|
|
| 2714 |
)
|
| 2715 |
|
| 2716 |
|
| 2717 |
+
def _configs_with_k_nonzeros(k: int, n_params: int, max_weight: int):
|
| 2718 |
"""
|
| 2719 |
+
Generate all n_params-length configs with exactly k nonzero values.
|
| 2720 |
+
Nonzero values range from -max_weight to +max_weight (excluding 0).
|
| 2721 |
"""
|
| 2722 |
+
if k > n_params or k < 0:
|
| 2723 |
+
return
|
| 2724 |
+
|
| 2725 |
+
nonzero_vals = list(range(-max_weight, 0)) + list(range(1, max_weight + 1))
|
| 2726 |
+
|
| 2727 |
+
for positions in combinations(range(n_params), k):
|
| 2728 |
+
position_set = set(positions)
|
| 2729 |
+
for vals in product(nonzero_vals, repeat=k):
|
| 2730 |
+
config = [0] * n_params
|
| 2731 |
+
for pos, val in zip(positions, vals):
|
| 2732 |
+
config[pos] = val
|
| 2733 |
+
yield tuple(config)
|
| 2734 |
+
|
| 2735 |
+
|
| 2736 |
+
def prune_exhaustive_sparse(circuit: AdaptiveCircuit, evaluator: BatchedEvaluator, cfg: Config) -> PruneResult:
    """
    Exhaustive search by sparsity level - finds provably maximum-sparsity solutions.

    Searches from 1 nonzero, then 2, then 3, ... until valid solutions found.
    Returns ALL valid solutions at the minimum nonzero count (to discover families).

    Useful for hardware where connection count matters more than weight magnitude.

    Args:
        circuit: Circuit whose flattened weight vector is searched; only its
            weights/stats/clone_weights interface is used here.
        evaluator: Batched fitness evaluator; falls back to per-config
            single evaluation if population evaluation raises.
        cfg: Uses sparse_max_weight, exhaustive_max_params, fitness_threshold,
            device, and verbose.

    Returns:
        PruneResult for the minimum-magnitude solution among all
        maximum-sparsity solutions, or the original weights if the search
        was skipped (too many params) or found nothing.
    """
    start = time.perf_counter()
    original = circuit.stats()

    n_params = original['total']
    original_nonzeros = original['nonzero']
    max_weight = cfg.sparse_max_weight

    # The search is combinatorial (C(n_params, k) * (2*max_weight)^k per
    # level), so refuse circuits above the configured parameter budget.
    if n_params > cfg.exhaustive_max_params:
        if cfg.verbose:
            print(f" [SPARSE] Skipping: {n_params} params exceeds max {cfg.exhaustive_max_params}")
        return PruneResult(
            method='exhaustive_sparse',
            original_stats=original,
            final_stats=original,
            final_weights=circuit.clone_weights(),
            fitness=evaluator.evaluate_single(circuit.weights),
            time_seconds=time.perf_counter() - start,
            metadata={'skipped': True, 'reason': 'too_many_params'}
        )

    if cfg.verbose:
        print(f" [SPARSE] Parameters: {n_params}")
        print(f" [SPARSE] Original nonzeros: {original_nonzeros}")
        print(f" [SPARSE] Max weight magnitude: {max_weight}")
        print(f" [SPARSE] Searching by nonzero count (1, 2, 3, ...)")

    weight_keys = list(circuit.weights.keys())
    weight_shapes = {k: circuit.weights[k].shape for k in weight_keys}
    weight_sizes = {k: circuit.weights[k].numel() for k in weight_keys}

    def vector_to_weights(vec):
        """Unflatten a flat config tuple back into the named-weight dict layout."""
        weights = {}
        idx = 0
        for k in weight_keys:
            size = weight_sizes[k]
            weights[k] = torch.tensor(vec[idx:idx+size], dtype=torch.float32, device=cfg.device).view(weight_shapes[k])
            idx += size
        return weights

    def _evaluate_batch(batch_configs):
        """Evaluate a batch of configs; return those meeting the fitness threshold.

        Fix: this logic was previously duplicated verbatim for the final
        partial batch, and used bare `except:` (which also swallows
        KeyboardInterrupt/SystemExit).
        """
        population = torch.tensor(batch_configs, dtype=torch.float32, device=cfg.device)
        try:
            fitness_batch = evaluator.evaluate_population(population)
        except Exception:
            # Batched evaluation may be unsupported for this circuit;
            # fall back to slower per-config evaluation.
            fitness_batch = torch.tensor([
                evaluator.evaluate_single(vector_to_weights(c))
                for c in batch_configs
            ], device=cfg.device)
        valid_mask = fitness_batch >= cfg.fitness_threshold
        return [batch_configs[i] for i, ok in enumerate(valid_mask.tolist()) if ok]

    total_tested = 0
    all_solutions = []
    optimal_nonzeros = None

    for n_nonzero in range(1, n_params + 1):
        nz_start = time.perf_counter()

        n_positions = math.comb(n_params, n_nonzero)
        n_value_combos = (2 * max_weight) ** n_nonzero
        n_configs = n_positions * n_value_combos

        if cfg.verbose:
            print(f" Nonzeros {n_nonzero}: {n_configs:,} configurations...", end=" ", flush=True)

        valid_at_nz = []
        batch_configs = []
        batch_size = min(100000, n_configs)

        for config in _configs_with_k_nonzeros(n_nonzero, n_params, max_weight):
            batch_configs.append(config)
            if len(batch_configs) >= batch_size:
                valid_at_nz.extend(_evaluate_batch(batch_configs))
                total_tested += len(batch_configs)
                batch_configs = []

        # Flush the final partial batch, if any.
        if batch_configs:
            valid_at_nz.extend(_evaluate_batch(batch_configs))
            total_tested += len(batch_configs)

        nz_time = time.perf_counter() - nz_start

        if valid_at_nz:
            if cfg.verbose:
                print(f"FOUND {len(valid_at_nz)} solutions! ({nz_time:.2f}s)")

            optimal_nonzeros = n_nonzero
            all_solutions = valid_at_nz

            if cfg.verbose:
                print(f" [SPARSE] Optimal nonzeros: {optimal_nonzeros}")
                print(f" [SPARSE] Solutions found: {len(all_solutions)}")
                print(f" [SPARSE] Solution analysis:")
                print(f" {'#':<3} {'NZ':<4} {'Mag':<6} {'Max|w|':<7} {'Weights'}")
                print(f" {'-'*60}")
                for i, sol in enumerate(all_solutions[:20]):
                    nz = sum(1 for v in sol if v != 0)
                    mag = sum(abs(v) for v in sol)
                    max_w = max(abs(v) for v in sol) if any(v != 0 for v in sol) else 0
                    print(f" {i+1:<3} {nz:<4} {mag:<6} {max_w:<7} {sol}")
                if len(all_solutions) > 20:
                    print(f" ... and {len(all_solutions) - 20} more")

            # First nonzero level with any solution is provably optimal.
            break
        else:
            if cfg.verbose:
                print(f"none ({nz_time:.2f}s)")

    elapsed = time.perf_counter() - start

    if all_solutions:
        # Among all maximum-sparsity solutions, prefer minimum total magnitude.
        best_combo = min(all_solutions, key=lambda x: sum(abs(v) for v in x))
        best_weights = vector_to_weights(best_combo)
        best_fitness = evaluator.evaluate_single(best_weights)
    else:
        best_weights = circuit.clone_weights()
        best_fitness = evaluator.evaluate_single(best_weights)
        optimal_nonzeros = original_nonzeros

    if cfg.verbose:
        # Fix: removed an unused `final_stats = circuit.stats(best_weights)`
        # assignment that none of these prints referenced.
        print(f" [SPARSE COMPLETE]")
        print(f" - Configurations tested: {total_tested:,}")
        print(f" - Optimal nonzeros: {optimal_nonzeros} (original: {original_nonzeros})")
        print(f" - Total solutions at optimal: {len(all_solutions)}")
        print(f" - Sparsity: {(1 - optimal_nonzeros/n_params)*100:.1f}%")
        print(f" - Time: {elapsed:.1f}s")

    return PruneResult(
        method='exhaustive_sparse',
        original_stats=original,
        final_stats=circuit.stats(best_weights),
        final_weights=best_weights,
        fitness=best_fitness,
        time_seconds=elapsed,
        metadata={
            'optimal_nonzeros': optimal_nonzeros,
            'total_tested': total_tested,
            'solutions_count': len(all_solutions),
            'all_solutions': all_solutions[:100]
        }
    )
|
| 2905 |
|
| 2906 |
|
| 2907 |
def run_all_methods(circuit: AdaptiveCircuit, cfg: Config) -> Dict[str, PruneResult]:
|
|
|
|
| 2938 |
('magnitude', cfg.run_magnitude, lambda: prune_magnitude(circuit, evaluator, cfg)),
|
| 2939 |
('zero', cfg.run_zero, lambda: prune_zero(circuit, evaluator, cfg)),
|
| 2940 |
('quantize', cfg.run_quantize, lambda: prune_quantize(circuit, evaluator, cfg)),
|
| 2941 |
+
('structural', cfg.run_structural, lambda: prune_structural(circuit, evaluator, cfg)),
|
|
|
|
| 2942 |
('topology', cfg.run_topology, lambda: prune_topology(circuit, evaluator, cfg)),
|
|
|
|
| 2943 |
('sensitivity', cfg.run_sensitivity, lambda: prune_sensitivity(circuit, evaluator, cfg)),
|
| 2944 |
('weight_sharing', cfg.run_weight_sharing, lambda: prune_weight_sharing(circuit, evaluator, cfg)),
|
| 2945 |
('depth', cfg.run_depth, lambda: prune_depth(circuit, evaluator, cfg)),
|
| 2946 |
('gate_subst', cfg.run_gate_subst, lambda: prune_gate_substitution(circuit, evaluator, cfg)),
|
| 2947 |
('symmetry', cfg.run_symmetry, lambda: prune_symmetry(circuit, evaluator, cfg)),
|
| 2948 |
('fanin', cfg.run_fanin, lambda: prune_fanin(circuit, evaluator, cfg)),
|
| 2949 |
+
('exhaustive_mag', cfg.run_exhaustive_mag, lambda: prune_exhaustive_mag(circuit, evaluator, cfg)),
|
| 2950 |
+
('exhaustive_sparse', cfg.run_exhaustive_sparse, lambda: prune_exhaustive_sparse(circuit, evaluator, cfg)),
|
| 2951 |
('evolutionary', cfg.run_evolutionary, lambda: prune_evolutionary(circuit, evaluator, cfg)),
|
| 2952 |
('annealing', cfg.run_annealing, lambda: prune_annealing(circuit, evaluator, cfg)),
|
|
|
|
| 2953 |
]
|
| 2954 |
|
| 2955 |
enabled_methods = [(name, fn) for name, enabled, fn in methods if enabled]
|
|
|
|
| 2995 |
reduction = 1 - best_mag / original['magnitude']
|
| 2996 |
print(f"\n BEST: {best_method} ({reduction * 100:.1f}% magnitude reduction)")
|
| 2997 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2998 |
return results
|
| 2999 |
|
| 3000 |
|
|
|
|
| 3119 |
parser.add_argument('--sa-iters', type=int, default=50000, help='Simulated annealing iterations')
|
| 3120 |
parser.add_argument('--sa-chains', type=int, default=0, help='Parallel SA chains (0=auto)')
|
| 3121 |
parser.add_argument('--vram-target', type=float, default=0.75)
|
|
|
|
| 3122 |
parser.add_argument('--fanin-target', type=int, default=4)
|
| 3123 |
+
parser.add_argument('--sparse-max-weight', type=int, default=3, help='Max weight magnitude for sparse search')
|
| 3124 |
+
parser.add_argument('--exhaustive-max-params', type=int, default=12, help='Max params for exhaustive search')
|
| 3125 |
|
| 3126 |
args = parser.parse_args()
|
| 3127 |
|
|
|
|
| 3143 |
evo_generations=args.evo_gens,
|
| 3144 |
annealing_iterations=args.sa_iters,
|
| 3145 |
annealing_parallel_chains=args.sa_chains,
|
| 3146 |
+
fanin_target=args.fanin_target,
|
| 3147 |
+
sparse_max_weight=args.sparse_max_weight,
|
| 3148 |
+
exhaustive_max_params=args.exhaustive_max_params
|
| 3149 |
)
|
| 3150 |
|
| 3151 |
if args.methods:
|
| 3152 |
all_methods = ['magnitude', 'zero', 'quantize', 'evolutionary', 'annealing',
|
| 3153 |
+
'structural', 'topology', 'sensitivity', 'weight_sharing',
|
| 3154 |
+
'depth', 'gate_subst', 'symmetry', 'fanin',
|
| 3155 |
+
'exhaustive_mag', 'exhaustive_sparse']
|
| 3156 |
for m in all_methods:
|
| 3157 |
setattr(cfg, f'run_{m}', False)
|
| 3158 |
|
|
|
|
| 3164 |
'quant': 'quantize', 'quantize': 'quantize',
|
| 3165 |
'evo': 'evolutionary', 'evolutionary': 'evolutionary',
|
| 3166 |
'anneal': 'annealing', 'annealing': 'annealing', 'sa': 'annealing',
|
| 3167 |
+
'structural': 'structural', 'struct': 'structural',
|
|
|
|
| 3168 |
'topo': 'topology', 'topology': 'topology',
|
|
|
|
| 3169 |
'sens': 'sensitivity', 'sensitivity': 'sensitivity',
|
| 3170 |
'share': 'weight_sharing', 'weight_sharing': 'weight_sharing', 'sharing': 'weight_sharing',
|
|
|
|
|
|
|
| 3171 |
'depth': 'depth',
|
| 3172 |
'gate': 'gate_subst', 'gate_subst': 'gate_subst', 'subst': 'gate_subst',
|
| 3173 |
'sym': 'symmetry', 'symmetry': 'symmetry',
|
| 3174 |
'fanin': 'fanin', 'fan': 'fanin',
|
| 3175 |
+
'exhaustive_mag': 'exhaustive_mag', 'exh_mag': 'exhaustive_mag', 'exh': 'exhaustive_mag', 'brute': 'exhaustive_mag',
|
| 3176 |
+
'exhaustive_sparse': 'exhaustive_sparse', 'exh_sparse': 'exhaustive_sparse', 'sparse': 'exhaustive_sparse'
|
| 3177 |
}
|
| 3178 |
if m in method_map:
|
| 3179 |
setattr(cfg, f'run_{method_map[m]}', True)
|
|
|
|
| 3233 |
print(" python prune.py threshold-hamming74decoder --methods evo")
|
| 3234 |
print(" python prune.py threshold-xor --methods evo --evo-pop 500000 --evo-gens 5000")
|
| 3235 |
print("")
|
| 3236 |
+
print(" # Exhaustive search (provably optimal):")
|
| 3237 |
+
print(" python prune.py threshold-xor --methods exh_mag # minimize magnitude")
|
| 3238 |
+
print(" python prune.py threshold-xor --methods exh_sparse # minimize nonzeros")
|
| 3239 |
+
print(" python prune.py threshold-mux --methods sparse --sparse-max-weight 2")
|
| 3240 |
+
print("")
|
| 3241 |
print(" # Pipeline mode (chained, each stage feeds into next):")
|
| 3242 |
print(" python prune.py threshold-hamming74decoder --pipeline evo,mag,zero,quant --save")
|
| 3243 |
print(" python prune.py threshold-xor --pipeline anneal,mag,zero --sa-iters 100000")
|