CharlesCNorton commited on
Commit
90f3f79
·
1 Parent(s): 1e96b5b

Remove eval/ folder, move prune_weights.py to root

Browse files

- Delete legacy eval/iron_eval.py, eval/comprehensive_eval.py
- Move prune_weights.py to root (imports from eval.py)
- Update README TODO

README.md CHANGED
@@ -460,9 +460,11 @@ The interface generalizes to **all** 65,536 8-bit additions once trained—no me
460
 
461
  | File | Description |
462
  |------|-------------|
463
- | `neural_computer.safetensors` | 6,296 tensors, 8,267,667 parameters |
464
- | `cpu/core.py` | CPU state, reference cycle, threshold runtime |
465
- | `eval/iron_eval.py` | Comprehensive test suite |
 
 
466
 
467
  ---
468
 
@@ -490,8 +492,8 @@ The interface generalizes to **all** 65,536 8-bit additions once trained—no me
490
  - [x] Extract shared utilities: `heaviside()`, `load_model()`, `create_population()`
491
  - [x] Unified evaluation with both batched speed and per-circuit reporting
492
  - [x] Read signal registry from safetensors metadata instead of routing.json
493
- - [ ] Remove `eval/` folder (legacy scripts, now superseded by root `eval.py`)
494
- - [ ] Integrate pruning into `eval.py` or update `prune_weights.py` to import from `eval.py`
495
 
496
  ---
497
 
 
460
 
461
  | File | Description |
462
  |------|-------------|
463
+ | `neural_computer.safetensors` | 9,429 tensors, 8,286,614 parameters |
464
+ | `threshold_cpu.py` | CPU state, reference cycle, threshold runtime |
465
+ | `eval.py` | Unified evaluation suite (5,282 tests, GPU-batched) |
466
+ | `build.py` | Build tools for memory circuits and .inputs tensors |
467
+ | `prune_weights.py` | Weight magnitude pruning |
468
 
469
  ---
470
 
 
492
  - [x] Extract shared utilities: `heaviside()`, `load_model()`, `create_population()`
493
  - [x] Unified evaluation with both batched speed and per-circuit reporting
494
  - [x] Read signal registry from safetensors metadata instead of routing.json
495
+ - [x] Remove `eval/` folder (legacy scripts, now superseded by root `eval.py`)
496
+ - [x] Update `prune_weights.py` to import from `eval.py`
497
 
498
  ---
499
 
eval/comprehensive_eval.py DELETED
The diff for this file is too large to render. See raw diff
 
eval/iron_eval.py DELETED
The diff for this file is too large to render. See raw diff
 
eval/prune_weights.py → prune_weights.py RENAMED
@@ -1,481 +1,481 @@
1
- """
2
- BATCHED WEIGHT PRUNING (GPU-optimized)
3
- ======================================
4
- Phase 1: Batch eval all candidates in parallel
5
- Phase 2: Apply all successes at once, binary search if conflicts
6
- """
7
-
8
- import torch
9
- import time
10
- import argparse
11
- from safetensors.torch import save_file
12
- from iron_eval import BatchedFitnessEvaluator, create_population, load_model
13
-
14
- torch.manual_seed(0)
15
-
16
-
17
- def format_time(seconds):
18
- if seconds < 60:
19
- return f"{seconds:.1f}s"
20
- elif seconds < 3600:
21
- return f"{seconds/60:.1f}m"
22
- else:
23
- return f"{seconds/3600:.1f}h"
24
-
25
-
26
- def format_eta(elapsed, done, total):
27
- if done == 0:
28
- return "calculating..."
29
- rate = done / elapsed
30
- remaining = (total - done) / rate
31
- return format_time(remaining)
32
-
33
-
34
- def apply_reductions(model, reductions):
35
- """Apply a list of (name, flat_idx, shape, old_val) reductions."""
36
- for name, flat_idx, shape, old_val in reductions:
37
- new_val = old_val - 1 if old_val > 0 else old_val + 1
38
- flat = model[name].flatten()
39
- if flat[flat_idx].item() == old_val:
40
- flat[flat_idx] = new_val
41
- model[name] = flat.view(shape)
42
-
43
-
44
- def revert_reductions(model, reductions):
45
- """Revert a list of reductions."""
46
- for name, flat_idx, shape, old_val in reductions:
47
- flat = model[name].flatten()
48
- new_val = old_val - 1 if old_val > 0 else old_val + 1
49
- if flat[flat_idx].item() == new_val:
50
- flat[flat_idx] = old_val
51
- model[name] = flat.view(shape)
52
-
53
-
54
- def check_fitness(model, evaluator, device):
55
- """Check model fitness."""
56
- torch.manual_seed(0)
57
- pop = create_population(model, 1, device)
58
- return evaluator.evaluate(pop, debug=False)[0].item()
59
-
60
-
61
- def sequential_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
62
- """
63
- Sequential fallback - tests and applies reductions one at a time.
64
- Slower but guarantees no interaction bugs.
65
- """
66
- accepted = []
67
- for i, (name, flat_idx, shape, old_val) in enumerate(candidates):
68
- apply_reductions(model, [(name, flat_idx, shape, old_val)])
69
- fitness = check_fitness(model, evaluator, device)
70
- if fitness >= 0.9999:
71
- accepted.append((name, flat_idx, shape, old_val))
72
- if (i + 1) % 50 == 0:
73
- current_mag = sum(t.abs().sum().item() for t in model.values())
74
- reduction_pct = 100 * (1 - current_mag / base_magnitude)
75
- print(f" Sequential: {len(accepted)}/{i+1} accepted | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
76
- else:
77
- revert_reductions(model, [(name, flat_idx, shape, old_val)])
78
- return accepted
79
-
80
-
81
- def batched_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
82
- """
83
- Batched binary search - evaluates multiple branches in parallel.
84
- Uses BFS instead of DFS to maximize batching opportunities.
85
- Verifies cumulative effect after each batch to prevent interaction bugs.
86
- """
87
- if len(candidates) == 0:
88
- return []
89
-
90
- # First try all at once
91
- print(f" Trying {len(candidates)} reductions at once...")
92
- apply_reductions(model, candidates)
93
- fitness = check_fitness(model, evaluator, device)
94
-
95
- if fitness >= 0.9999:
96
- current_mag = sum(t.abs().sum().item() for t in model.values())
97
- reduction_pct = 100 * (1 - current_mag / base_magnitude)
98
- print(f" ALL {len(candidates)} OK | fitness={fitness:.6f} | "
99
- f"mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
100
- return candidates
101
-
102
- # Conflict - revert and use batched BFS
103
- revert_reductions(model, candidates)
104
- print(f" CONFLICT (fitness={fitness:.6f}), starting batched resolution...")
105
-
106
- accepted = []
107
- # Queue of (candidate_list, depth) to process
108
- pending = [(candidates, 0)]
109
-
110
- while pending:
111
- # Collect all pending groups for batch evaluation
112
- to_eval = []
113
- for group, depth in pending:
114
- if len(group) == 0:
115
- continue
116
- elif len(group) == 1:
117
- to_eval.append((group, depth, 'single'))
118
- else:
119
- to_eval.append((group, depth, 'group'))
120
-
121
- pending = []
122
-
123
- if not to_eval:
124
- break
125
-
126
- # Build batch: create model variants for each group
127
- batch_size = len(to_eval)
128
- print(f" Batch evaluating {batch_size} groups...")
129
-
130
- # Create population for batch eval
131
- pop = {}
132
- for name, tensor in model.items():
133
- pop[name] = tensor.unsqueeze(0).expand(batch_size, *tensor.shape).clone().to(device)
134
-
135
- # Apply each group's reductions to its population slot
136
- for idx, (group, depth, gtype) in enumerate(to_eval):
137
- for name, flat_idx, shape, old_val in group:
138
- new_val = old_val - 1 if old_val > 0 else old_val + 1
139
- flat_view = pop[name][idx].flatten()
140
- # Check if not already modified in base model
141
- base_val = model[name].flatten()[flat_idx].item()
142
- if base_val == old_val:
143
- flat_view[flat_idx] = new_val
144
-
145
- # Batch evaluate
146
- torch.manual_seed(0)
147
- fitnesses = evaluator.evaluate(pop, debug=False)
148
-
149
- # Process results - collect accepted groups first, then verify
150
- batch_accepted = []
151
- ok_count = 0
152
- conflict_count = 0
153
- fail_count = 0
154
-
155
- for idx, (group, depth, gtype) in enumerate(to_eval):
156
- fit = fitnesses[idx].item()
157
- indent = " " + " " * depth
158
-
159
- if fit >= 0.9999:
160
- batch_accepted.append((group, depth, indent))
161
- ok_count += len(group)
162
- else:
163
- if len(group) == 1:
164
- name, flat_idx, shape, old_val = group[0]
165
- print(f"{indent}[1/1] FAIL {name}[{flat_idx}] | fitness={fit:.6f}")
166
- fail_count += 1
167
- else:
168
- mid = len(group) // 2
169
- left = group[:mid]
170
- right = group[mid:]
171
- print(f"{indent}CONFLICT ({len(group)}) fitness={fit:.6f} -> split {len(left)}+{len(right)}")
172
- pending.append((left, depth + 1))
173
- pending.append((right, depth + 1))
174
- conflict_count += 1
175
-
176
- # Apply all batch-accepted reductions
177
- all_batch_reductions = []
178
- for group, depth, indent in batch_accepted:
179
- apply_reductions(model, group)
180
- all_batch_reductions.extend(group)
181
-
182
- # Verify cumulative effect
183
- if all_batch_reductions:
184
- verify_fitness = check_fitness(model, evaluator, device)
185
- if verify_fitness >= 0.9999:
186
- # All good - commit these reductions
187
- for group, depth, indent in batch_accepted:
188
- current_mag = sum(t.abs().sum().item() for t in model.values())
189
- reduction_pct = 100 * (1 - current_mag / base_magnitude)
190
- if len(group) == 1:
191
- name, flat_idx, shape, old_val = group[0]
192
- print(f"{indent}[1/1] OK {name}[{flat_idx}] | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
193
- else:
194
- print(f"{indent}ALL {len(group)} OK | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
195
- accepted.extend(all_batch_reductions)
196
- print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed")
197
- else:
198
- # Interaction bug detected - revert and use sequential fallback
199
- print(f" INTERACTION BUG detected (batch fitness={verify_fitness:.6f})")
200
- print(f" Reverting {len(all_batch_reductions)} reductions, falling back to sequential...")
201
- revert_reductions(model, all_batch_reductions)
202
-
203
- # Process each group sequentially
204
- seq_accepted = sequential_conflict_resolution(
205
- model, evaluator, device, all_batch_reductions, base_magnitude
206
- )
207
- accepted.extend(seq_accepted)
208
- print(f" Sequential fallback: {len(seq_accepted)}/{len(all_batch_reductions)} accepted")
209
- else:
210
- print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed")
211
-
212
- return accepted
213
-
214
-
215
- def prune_weights(
216
- passes: int = 10,
217
- batch_size: int = 5000,
218
- device: str = 'cuda',
219
- checkpoint_path: str = "D:/8bit-threshold-computer/pruned.safetensors"
220
- ):
221
- print("=" * 80)
222
- print(" BATCHED WEIGHT PRUNING (GPU-optimized)")
223
- print("=" * 80)
224
- print(f" Device: {device}")
225
- print(f" Batch size: {batch_size}")
226
- print(f" Max passes: {passes}")
227
- print("=" * 80)
228
-
229
- # Load model
230
- print("\n[1/4] LOADING MODEL...")
231
- load_start = time.perf_counter()
232
- model = load_model()
233
- load_time = time.perf_counter() - load_start
234
-
235
- n_params = sum(t.numel() for t in model.values())
236
- n_tensors = len(model)
237
- base_magnitude = sum(t.abs().sum().item() for t in model.values())
238
- base_max = max(t.abs().max().item() for t in model.values())
239
- nonzero_params = sum((t != 0).sum().item() for t in model.values())
240
-
241
- print(f" Loaded in {load_time:.2f}s")
242
- print(f" Tensors: {n_tensors}")
243
- print(f" Parameters: {n_params}")
244
- print(f" Non-zero parameters: {nonzero_params}")
245
- print(f" Total magnitude: {base_magnitude:.0f}")
246
- print(f" Max weight: {base_max:.0f}")
247
-
248
- # Initialize evaluator
249
- print("\n[2/4] INITIALIZING EVALUATOR...")
250
- eval_start = time.perf_counter()
251
- evaluator = BatchedFitnessEvaluator(device=device)
252
- eval_time = time.perf_counter() - eval_start
253
- print(f" Initialized in {eval_time:.2f}s")
254
-
255
- # Verify initial fitness
256
- print("\n[3/4] VERIFYING BASE MODEL...")
257
- initial_fitness = check_fitness(model, evaluator, device)
258
- print(f" Fitness: {initial_fitness:.6f}")
259
-
260
- if initial_fitness < 0.9999:
261
- print(f" ERROR: Base model fitness {initial_fitness:.6f} < 0.9999")
262
- return None
263
-
264
- print(f" STATUS: PASS")
265
-
266
- # Build parameter list
267
- print("\n[4/4] BUILDING PARAMETER INDEX...")
268
- param_list = []
269
- for name, tensor in model.items():
270
- flat = tensor.flatten()
271
- for i in range(len(flat)):
272
- param_list.append((name, i, tensor.shape))
273
- print(f" Indexed {len(param_list)} parameters")
274
-
275
- # Main pruning loop
276
- print("\n" + "=" * 80)
277
- print(" PRUNING STARTED")
278
- print("=" * 80)
279
-
280
- total_reductions = 0
281
- pruning_start = time.perf_counter()
282
-
283
- for pass_num in range(passes):
284
- torch.manual_seed(0)
285
- pass_start = time.perf_counter()
286
-
287
- print(f"\n{'='*80}")
288
- print(f" PASS {pass_num + 1}/{passes}")
289
- print(f"{'='*80}")
290
-
291
- # Count candidates
292
- candidates = []
293
- for name, idx, shape in param_list:
294
- flat = model[name].flatten()
295
- val = flat[idx].item()
296
- if val != 0:
297
- candidates.append((name, idx, shape, val))
298
-
299
- n_candidates = len(candidates)
300
- print(f"\n Candidates: {n_candidates} non-zero weights")
301
-
302
- if n_candidates == 0:
303
- print(f" No candidates remaining. Stopping.")
304
- break
305
-
306
- # Phase 1: Batch evaluation
307
- print(f"\n PHASE 1: Batch evaluation (testing each reduction independently)")
308
- print(f" " + "-" * 60)
309
- phase1_start = time.perf_counter()
310
- successful_candidates = []
311
- n_batches = (n_candidates + batch_size - 1) // batch_size
312
-
313
- for batch_idx, batch_start_idx in enumerate(range(0, n_candidates, batch_size)):
314
- batch = candidates[batch_start_idx:batch_start_idx + batch_size]
315
- batch_len = len(batch)
316
- batch_start_time = time.perf_counter()
317
-
318
- # Build population
319
- pop = {}
320
- for name, tensor in model.items():
321
- pop[name] = tensor.unsqueeze(0).expand(batch_len, *tensor.shape).clone().to(device)
322
-
323
- # Apply reductions
324
- for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
325
- new_val = old_val - 1 if old_val > 0 else old_val + 1
326
- flat_view = pop[name][pop_idx].flatten()
327
- flat_view[flat_idx] = new_val
328
-
329
- # Evaluate
330
- torch.manual_seed(0)
331
- if device == 'cuda':
332
- torch.cuda.synchronize()
333
- fitness = evaluator.evaluate(pop, debug=False)
334
- if device == 'cuda':
335
- torch.cuda.synchronize()
336
-
337
- # Collect successes
338
- batch_successes = 0
339
- for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
340
- if fitness[pop_idx].item() >= 0.9999:
341
- successful_candidates.append((name, flat_idx, shape, old_val))
342
- batch_successes += 1
343
-
344
- batch_time = time.perf_counter() - batch_start_time
345
- elapsed = time.perf_counter() - phase1_start
346
- done = batch_start_idx + batch_len
347
- eta = format_eta(elapsed, done, n_candidates)
348
- throughput = batch_len / batch_time
349
-
350
- print(f" Batch {batch_idx + 1}/{n_batches}: "
351
- f"{batch_successes}/{batch_len} passed ({100*batch_successes/batch_len:.1f}%) | "
352
- f"Total OK: {len(successful_candidates)} | "
353
- f"Progress: {done}/{n_candidates} ({100*done/n_candidates:.1f}%) | "
354
- f"Speed: {throughput:.0f}/s | "
355
- f"ETA: {eta}")
356
-
357
- phase1_time = time.perf_counter() - phase1_start
358
- print(f"\n Phase 1 complete: {len(successful_candidates)}/{n_candidates} candidates "
359
- f"({100*len(successful_candidates)/n_candidates:.1f}%) in {format_time(phase1_time)}")
360
-
361
- # Phase 2: Apply with conflict resolution
362
- if len(successful_candidates) == 0:
363
- print(f"\n No reductions possible. Stopping.")
364
- break
365
-
366
- print(f"\n PHASE 2: Apply reductions with conflict resolution")
367
- print(f" " + "-" * 60)
368
- phase2_start = time.perf_counter()
369
-
370
- accepted = batched_conflict_resolution(model, evaluator, device, successful_candidates, base_magnitude)
371
- pass_reductions = len(accepted)
372
-
373
- phase2_time = time.perf_counter() - phase2_start
374
- print(f"\n Phase 2 complete: {pass_reductions} reductions applied in {format_time(phase2_time)}")
375
-
376
- # Pass summary
377
- total_reductions += pass_reductions
378
- current_magnitude = sum(t.abs().sum().item() for t in model.values())
379
- current_nonzero = sum((t != 0).sum().item() for t in model.values())
380
- pass_time = time.perf_counter() - pass_start
381
- reduction_pct = 100 * (1 - current_magnitude / base_magnitude)
382
-
383
- print(f"\n PASS {pass_num + 1} SUMMARY:")
384
- print(f" Reductions this pass: {pass_reductions}")
385
- print(f" Total reductions: {total_reductions}")
386
- print(f" Current magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")
387
- print(f" Current non-zero: {current_nonzero}")
388
- print(f" Pass time: {format_time(pass_time)}")
389
-
390
- # Verify after pass
391
- print(f"\n Verifying model integrity...")
392
- fitness = check_fitness(model, evaluator, device)
393
- print(f" Fitness: {fitness:.6f} {'PASS' if fitness >= 0.9999 else 'FAIL'}")
394
-
395
- # Save checkpoint after each pass
396
- checkpoint_name = checkpoint_path.replace('.safetensors', f'_pass{pass_num + 1}.safetensors')
397
- print(f"\n Saving checkpoint: {checkpoint_name}")
398
- save_file(model, checkpoint_name)
399
- print(f" Saved. Magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")
400
-
401
- # Also save as "latest" for easy access
402
- latest_path = checkpoint_path.replace('.safetensors', '_latest.safetensors')
403
- save_file(model, latest_path)
404
- print(f" Also saved as: {latest_path}")
405
-
406
- if pass_reductions == 0:
407
- print(f"\n No reductions achieved. Stopping early.")
408
- break
409
-
410
- # Final summary
411
- pruning_time = time.perf_counter() - pruning_start
412
- final_magnitude = sum(t.abs().sum().item() for t in model.values())
413
- final_max = max(t.abs().max().item() for t in model.values())
414
- final_nonzero = sum((t != 0).sum().item() for t in model.values())
415
- reduction_pct = 100 * (1 - final_magnitude / base_magnitude)
416
-
417
- print("\n" + "=" * 80)
418
- print(" PRUNING COMPLETE")
419
- print("=" * 80)
420
- print(f"\n RESULTS:")
421
- print(f" Original magnitude: {base_magnitude:.0f}")
422
- print(f" Final magnitude: {final_magnitude:.0f}")
423
- print(f" Reduction: {reduction_pct:.2f}%")
424
- print(f" Total reductions: {total_reductions}")
425
- print(f" Original non-zero: {nonzero_params}")
426
- print(f" Final non-zero: {final_nonzero}")
427
- print(f" Zeros created: {nonzero_params - final_nonzero}")
428
- print(f" Max weight: {final_max:.0f}")
429
- print(f" Total time: {format_time(pruning_time)}")
430
-
431
- # Save
432
- print(f"\n SAVING to {checkpoint_path}...")
433
- save_file(model, checkpoint_path)
434
- print(f" Saved.")
435
-
436
- # Final verification
437
- print(f"\n FINAL VERIFICATION...")
438
- from safetensors import safe_open
439
- f = safe_open(checkpoint_path, framework='numpy')
440
- verify_model = {name: torch.tensor(f.get_tensor(name)).float() for name in f.keys()}
441
- verify_fitness = check_fitness(verify_model, evaluator, device)
442
- print(f" Fitness: {verify_fitness:.6f}")
443
-
444
- if verify_fitness >= 0.9999:
445
- print(f" STATUS: PASS")
446
- else:
447
- print(f" STATUS: FAIL - Model corrupted!")
448
-
449
- print("\n" + "=" * 80)
450
- return model
451
-
452
-
453
- MAX_BATCH_SIZE = 80000
454
-
455
- if __name__ == "__main__":
456
- parser = argparse.ArgumentParser(description='Batched Weight Pruning')
457
- parser.add_argument('--passes', type=int, default=10,
458
- help='Maximum pruning passes (default: 10)')
459
- parser.add_argument('--batch_size', type=int, default=80000,
460
- help=f'Batch size for parallel evaluation (default: 80000, max: {MAX_BATCH_SIZE})')
461
- parser.add_argument('--device', type=str, default='cuda',
462
- help='Device: cuda or cpu (default: cuda)')
463
- parser.add_argument('--output', type=str,
464
- default='D:/8bit-threshold-computer/pruned.safetensors',
465
- help='Output path')
466
- args = parser.parse_args()
467
-
468
- if args.batch_size > MAX_BATCH_SIZE:
469
- print(f"WARNING: batch_size {args.batch_size} exceeds maximum {MAX_BATCH_SIZE}. Clamping.")
470
- args.batch_size = MAX_BATCH_SIZE
471
-
472
- print(f"\nStarting at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
473
-
474
- prune_weights(
475
- passes=args.passes,
476
- batch_size=args.batch_size,
477
- device=args.device,
478
- checkpoint_path=args.output
479
- )
480
-
481
- print(f"\nFinished at {time.strftime('%Y-%m-%d %H:%M:%S')}")
 
1
+ """
2
+ BATCHED WEIGHT PRUNING (GPU-optimized)
3
+ ======================================
4
+ Phase 1: Batch eval all candidates in parallel
5
+ Phase 2: Apply all successes at once, binary search if conflicts
6
+ """
7
+
8
+ import torch
9
+ import time
10
+ import argparse
11
+ from safetensors.torch import save_file
12
+ from eval import BatchedFitnessEvaluator, create_population, load_model
13
+
14
+ torch.manual_seed(0)
15
+
16
+
17
def format_time(seconds):
    """Render a duration in seconds as a short human-readable string.

    Uses hours above one hour, minutes above one minute, seconds otherwise,
    always with one decimal place.
    """
    if seconds >= 3600:
        return f"{seconds/3600:.1f}h"
    if seconds >= 60:
        return f"{seconds/60:.1f}m"
    return f"{seconds:.1f}s"
24
+
25
+
26
def format_eta(elapsed, done, total):
    """Estimate remaining time from progress so far.

    Args:
        elapsed: Seconds spent so far.
        done: Units of work completed.
        total: Total units of work.

    Returns:
        A human-readable ETA string, or "calculating..." when no rate can
        be derived yet.
    """
    # Guard elapsed <= 0 as well as done == 0: the first progress report can
    # arrive before any measurable time has elapsed, and done / elapsed would
    # raise ZeroDivisionError.
    if done == 0 or elapsed <= 0:
        return "calculating..."
    rate = done / elapsed
    remaining = (total - done) / rate
    return format_time(remaining)
32
+
33
+
34
def apply_reductions(model, reductions):
    """Step each listed weight one unit toward zero.

    Each reduction is a (name, flat_idx, shape, old_val) tuple. An entry is
    applied only if the weight still holds old_val, so replaying a reduction
    that was already applied (or superseded) is a no-op.
    """
    for name, flat_idx, shape, old_val in reductions:
        flat = model[name].flatten()
        if flat[flat_idx].item() != old_val:
            continue  # weight already changed elsewhere; skip to stay idempotent
        flat[flat_idx] = old_val - 1 if old_val > 0 else old_val + 1
        model[name] = flat.view(shape)
42
+
43
+
44
def revert_reductions(model, reductions):
    """Undo reductions previously applied by apply_reductions.

    For each (name, flat_idx, shape, old_val) tuple, the weight is restored
    to old_val only if it currently holds the stepped value, so reverting an
    entry that was never applied is a no-op.
    """
    for name, flat_idx, shape, old_val in reductions:
        stepped = old_val - 1 if old_val > 0 else old_val + 1
        flat = model[name].flatten()
        if flat[flat_idx].item() != stepped:
            continue  # weight is not in the stepped state; nothing to revert
        flat[flat_idx] = old_val
        model[name] = flat.view(shape)
52
+
53
+
54
def check_fitness(model, evaluator, device):
    """Evaluate a single-individual population and return its scalar fitness."""
    torch.manual_seed(0)  # fixed seed so repeated checks are deterministic
    population = create_population(model, 1, device)
    scores = evaluator.evaluate(population, debug=False)
    return scores[0].item()
59
+
60
+
61
def sequential_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
    """
    Sequential fallback: test and apply reductions one at a time.

    Slower than the batched path, but each candidate is validated against the
    live model, so weight-interaction effects cannot slip through. Returns the
    list of accepted (name, flat_idx, shape, old_val) reductions.
    """
    accepted = []
    for i, candidate in enumerate(candidates):
        apply_reductions(model, [candidate])
        if check_fitness(model, evaluator, device) >= 0.9999:
            accepted.append(candidate)
            # Periodic progress report every 50 candidates processed.
            if (i + 1) % 50 == 0:
                current_mag = sum(t.abs().sum().item() for t in model.values())
                reduction_pct = 100 * (1 - current_mag / base_magnitude)
                print(f" Sequential: {len(accepted)}/{i+1} accepted | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
        else:
            # Fitness regressed: put the weight back and move on.
            revert_reductions(model, [candidate])
    return accepted
79
+
80
+
81
def batched_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
    """
    Batched binary search - evaluates multiple branches in parallel.
    Uses BFS instead of DFS to maximize batching opportunities.
    Verifies cumulative effect after each batch to prevent interaction bugs.

    Args:
        model: dict mapping tensor name -> weight tensor (mutated in place).
        evaluator: object exposing evaluate(population, debug=...) -> fitness tensor.
        device: torch device string the population copies are moved to.
        candidates: list of (name, flat_idx, shape, old_val) reduction tuples.
        base_magnitude: total |weight| sum of the unpruned model, for reporting.

    Returns:
        The list of reduction tuples that were accepted and remain applied.
    """
    if len(candidates) == 0:
        return []

    # First try all at once
    print(f" Trying {len(candidates)} reductions at once...")
    apply_reductions(model, candidates)
    fitness = check_fitness(model, evaluator, device)

    if fitness >= 0.9999:
        # Fast path: every reduction coexists; keep them all applied.
        current_mag = sum(t.abs().sum().item() for t in model.values())
        reduction_pct = 100 * (1 - current_mag / base_magnitude)
        print(f" ALL {len(candidates)} OK | fitness={fitness:.6f} | "
              f"mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
        return candidates

    # Conflict - revert and use batched BFS
    revert_reductions(model, candidates)
    print(f" CONFLICT (fitness={fitness:.6f}), starting batched resolution...")

    accepted = []
    # Queue of (candidate_list, depth) to process
    pending = [(candidates, 0)]

    while pending:
        # Collect all pending groups for batch evaluation
        to_eval = []
        for group, depth in pending:
            if len(group) == 0:
                continue
            elif len(group) == 1:
                to_eval.append((group, depth, 'single'))
            else:
                to_eval.append((group, depth, 'group'))

        pending = []

        if not to_eval:
            break

        # Build batch: create model variants for each group
        batch_size = len(to_eval)
        print(f" Batch evaluating {batch_size} groups...")

        # Create population for batch eval
        # (one cloned copy of every tensor per group under test)
        pop = {}
        for name, tensor in model.items():
            pop[name] = tensor.unsqueeze(0).expand(batch_size, *tensor.shape).clone().to(device)

        # Apply each group's reductions to its population slot
        for idx, (group, depth, gtype) in enumerate(to_eval):
            for name, flat_idx, shape, old_val in group:
                new_val = old_val - 1 if old_val > 0 else old_val + 1
                flat_view = pop[name][idx].flatten()
                # Check if not already modified in base model
                base_val = model[name].flatten()[flat_idx].item()
                if base_val == old_val:
                    flat_view[flat_idx] = new_val

        # Batch evaluate
        torch.manual_seed(0)
        fitnesses = evaluator.evaluate(pop, debug=False)

        # Process results - collect accepted groups first, then verify
        batch_accepted = []
        ok_count = 0
        conflict_count = 0
        fail_count = 0

        for idx, (group, depth, gtype) in enumerate(to_eval):
            fit = fitnesses[idx].item()
            indent = " " + " " * depth

            if fit >= 0.9999:
                batch_accepted.append((group, depth, indent))
                ok_count += len(group)
            else:
                if len(group) == 1:
                    # A single failing reduction is rejected outright.
                    name, flat_idx, shape, old_val = group[0]
                    print(f"{indent}[1/1] FAIL {name}[{flat_idx}] | fitness={fit:.6f}")
                    fail_count += 1
                else:
                    # Binary-split the failing group and requeue both halves
                    # for the next BFS round.
                    mid = len(group) // 2
                    left = group[:mid]
                    right = group[mid:]
                    print(f"{indent}CONFLICT ({len(group)}) fitness={fit:.6f} -> split {len(left)}+{len(right)}")
                    pending.append((left, depth + 1))
                    pending.append((right, depth + 1))
                    conflict_count += 1

        # Apply all batch-accepted reductions
        all_batch_reductions = []
        for group, depth, indent in batch_accepted:
            apply_reductions(model, group)
            all_batch_reductions.extend(group)

        # Verify cumulative effect
        # (groups passed individually, but they may interact once combined)
        if all_batch_reductions:
            verify_fitness = check_fitness(model, evaluator, device)
            if verify_fitness >= 0.9999:
                # All good - commit these reductions
                for group, depth, indent in batch_accepted:
                    current_mag = sum(t.abs().sum().item() for t in model.values())
                    reduction_pct = 100 * (1 - current_mag / base_magnitude)
                    if len(group) == 1:
                        name, flat_idx, shape, old_val = group[0]
                        print(f"{indent}[1/1] OK {name}[{flat_idx}] | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
                    else:
                        print(f"{indent}ALL {len(group)} OK | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
                accepted.extend(all_batch_reductions)
                print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed")
            else:
                # Interaction bug detected - revert and use sequential fallback
                print(f" INTERACTION BUG detected (batch fitness={verify_fitness:.6f})")
                print(f" Reverting {len(all_batch_reductions)} reductions, falling back to sequential...")
                revert_reductions(model, all_batch_reductions)

                # Process each group sequentially
                seq_accepted = sequential_conflict_resolution(
                    model, evaluator, device, all_batch_reductions, base_magnitude
                )
                accepted.extend(seq_accepted)
                print(f" Sequential fallback: {len(seq_accepted)}/{len(all_batch_reductions)} accepted")
        else:
            print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed")

    return accepted
213
+
214
+
215
+ def prune_weights(
216
+ passes: int = 10,
217
+ batch_size: int = 5000,
218
+ device: str = 'cuda',
219
+ checkpoint_path: str = "D:/8bit-threshold-computer/pruned.safetensors"
220
+ ):
221
+ print("=" * 80)
222
+ print(" BATCHED WEIGHT PRUNING (GPU-optimized)")
223
+ print("=" * 80)
224
+ print(f" Device: {device}")
225
+ print(f" Batch size: {batch_size}")
226
+ print(f" Max passes: {passes}")
227
+ print("=" * 80)
228
+
229
+ # Load model
230
+ print("\n[1/4] LOADING MODEL...")
231
+ load_start = time.perf_counter()
232
+ model = load_model()
233
+ load_time = time.perf_counter() - load_start
234
+
235
+ n_params = sum(t.numel() for t in model.values())
236
+ n_tensors = len(model)
237
+ base_magnitude = sum(t.abs().sum().item() for t in model.values())
238
+ base_max = max(t.abs().max().item() for t in model.values())
239
+ nonzero_params = sum((t != 0).sum().item() for t in model.values())
240
+
241
+ print(f" Loaded in {load_time:.2f}s")
242
+ print(f" Tensors: {n_tensors}")
243
+ print(f" Parameters: {n_params}")
244
+ print(f" Non-zero parameters: {nonzero_params}")
245
+ print(f" Total magnitude: {base_magnitude:.0f}")
246
+ print(f" Max weight: {base_max:.0f}")
247
+
248
+ # Initialize evaluator
249
+ print("\n[2/4] INITIALIZING EVALUATOR...")
250
+ eval_start = time.perf_counter()
251
+ evaluator = BatchedFitnessEvaluator(device=device)
252
+ eval_time = time.perf_counter() - eval_start
253
+ print(f" Initialized in {eval_time:.2f}s")
254
+
255
+ # Verify initial fitness
256
+ print("\n[3/4] VERIFYING BASE MODEL...")
257
+ initial_fitness = check_fitness(model, evaluator, device)
258
+ print(f" Fitness: {initial_fitness:.6f}")
259
+
260
+ if initial_fitness < 0.9999:
261
+ print(f" ERROR: Base model fitness {initial_fitness:.6f} < 0.9999")
262
+ return None
263
+
264
+ print(f" STATUS: PASS")
265
+
266
+ # Build parameter list
267
+ print("\n[4/4] BUILDING PARAMETER INDEX...")
268
+ param_list = []
269
+ for name, tensor in model.items():
270
+ flat = tensor.flatten()
271
+ for i in range(len(flat)):
272
+ param_list.append((name, i, tensor.shape))
273
+ print(f" Indexed {len(param_list)} parameters")
274
+
275
+ # Main pruning loop
276
+ print("\n" + "=" * 80)
277
+ print(" PRUNING STARTED")
278
+ print("=" * 80)
279
+
280
+ total_reductions = 0
281
+ pruning_start = time.perf_counter()
282
+
283
+ for pass_num in range(passes):
284
+ torch.manual_seed(0)
285
+ pass_start = time.perf_counter()
286
+
287
+ print(f"\n{'='*80}")
288
+ print(f" PASS {pass_num + 1}/{passes}")
289
+ print(f"{'='*80}")
290
+
291
+ # Count candidates
292
+ candidates = []
293
+ for name, idx, shape in param_list:
294
+ flat = model[name].flatten()
295
+ val = flat[idx].item()
296
+ if val != 0:
297
+ candidates.append((name, idx, shape, val))
298
+
299
+ n_candidates = len(candidates)
300
+ print(f"\n Candidates: {n_candidates} non-zero weights")
301
+
302
+ if n_candidates == 0:
303
+ print(f" No candidates remaining. Stopping.")
304
+ break
305
+
306
+ # Phase 1: Batch evaluation
307
+ print(f"\n PHASE 1: Batch evaluation (testing each reduction independently)")
308
+ print(f" " + "-" * 60)
309
+ phase1_start = time.perf_counter()
310
+ successful_candidates = []
311
+ n_batches = (n_candidates + batch_size - 1) // batch_size
312
+
313
+ for batch_idx, batch_start_idx in enumerate(range(0, n_candidates, batch_size)):
314
+ batch = candidates[batch_start_idx:batch_start_idx + batch_size]
315
+ batch_len = len(batch)
316
+ batch_start_time = time.perf_counter()
317
+
318
+ # Build population
319
+ pop = {}
320
+ for name, tensor in model.items():
321
+ pop[name] = tensor.unsqueeze(0).expand(batch_len, *tensor.shape).clone().to(device)
322
+
323
+ # Apply reductions
324
+ for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
325
+ new_val = old_val - 1 if old_val > 0 else old_val + 1
326
+ flat_view = pop[name][pop_idx].flatten()
327
+ flat_view[flat_idx] = new_val
328
+
329
+ # Evaluate
330
+ torch.manual_seed(0)
331
+ if device == 'cuda':
332
+ torch.cuda.synchronize()
333
+ fitness = evaluator.evaluate(pop, debug=False)
334
+ if device == 'cuda':
335
+ torch.cuda.synchronize()
336
+
337
+ # Collect successes
338
+ batch_successes = 0
339
+ for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
340
+ if fitness[pop_idx].item() >= 0.9999:
341
+ successful_candidates.append((name, flat_idx, shape, old_val))
342
+ batch_successes += 1
343
+
344
+ batch_time = time.perf_counter() - batch_start_time
345
+ elapsed = time.perf_counter() - phase1_start
346
+ done = batch_start_idx + batch_len
347
+ eta = format_eta(elapsed, done, n_candidates)
348
+ throughput = batch_len / batch_time
349
+
350
+ print(f" Batch {batch_idx + 1}/{n_batches}: "
351
+ f"{batch_successes}/{batch_len} passed ({100*batch_successes/batch_len:.1f}%) | "
352
+ f"Total OK: {len(successful_candidates)} | "
353
+ f"Progress: {done}/{n_candidates} ({100*done/n_candidates:.1f}%) | "
354
+ f"Speed: {throughput:.0f}/s | "
355
+ f"ETA: {eta}")
356
+
357
+ phase1_time = time.perf_counter() - phase1_start
358
+ print(f"\n Phase 1 complete: {len(successful_candidates)}/{n_candidates} candidates "
359
+ f"({100*len(successful_candidates)/n_candidates:.1f}%) in {format_time(phase1_time)}")
360
+
361
+ # Phase 2: Apply with conflict resolution
362
+ if len(successful_candidates) == 0:
363
+ print(f"\n No reductions possible. Stopping.")
364
+ break
365
+
366
+ print(f"\n PHASE 2: Apply reductions with conflict resolution")
367
+ print(f" " + "-" * 60)
368
+ phase2_start = time.perf_counter()
369
+
370
+ accepted = batched_conflict_resolution(model, evaluator, device, successful_candidates, base_magnitude)
371
+ pass_reductions = len(accepted)
372
+
373
+ phase2_time = time.perf_counter() - phase2_start
374
+ print(f"\n Phase 2 complete: {pass_reductions} reductions applied in {format_time(phase2_time)}")
375
+
376
+ # Pass summary
377
+ total_reductions += pass_reductions
378
+ current_magnitude = sum(t.abs().sum().item() for t in model.values())
379
+ current_nonzero = sum((t != 0).sum().item() for t in model.values())
380
+ pass_time = time.perf_counter() - pass_start
381
+ reduction_pct = 100 * (1 - current_magnitude / base_magnitude)
382
+
383
+ print(f"\n PASS {pass_num + 1} SUMMARY:")
384
+ print(f" Reductions this pass: {pass_reductions}")
385
+ print(f" Total reductions: {total_reductions}")
386
+ print(f" Current magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")
387
+ print(f" Current non-zero: {current_nonzero}")
388
+ print(f" Pass time: {format_time(pass_time)}")
389
+
390
+ # Verify after pass
391
+ print(f"\n Verifying model integrity...")
392
+ fitness = check_fitness(model, evaluator, device)
393
+ print(f" Fitness: {fitness:.6f} {'PASS' if fitness >= 0.9999 else 'FAIL'}")
394
+
395
+ # Save checkpoint after each pass
396
+ checkpoint_name = checkpoint_path.replace('.safetensors', f'_pass{pass_num + 1}.safetensors')
397
+ print(f"\n Saving checkpoint: {checkpoint_name}")
398
+ save_file(model, checkpoint_name)
399
+ print(f" Saved. Magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")
400
+
401
+ # Also save as "latest" for easy access
402
+ latest_path = checkpoint_path.replace('.safetensors', '_latest.safetensors')
403
+ save_file(model, latest_path)
404
+ print(f" Also saved as: {latest_path}")
405
+
406
+ if pass_reductions == 0:
407
+ print(f"\n No reductions achieved. Stopping early.")
408
+ break
409
+
410
+ # Final summary
411
+ pruning_time = time.perf_counter() - pruning_start
412
+ final_magnitude = sum(t.abs().sum().item() for t in model.values())
413
+ final_max = max(t.abs().max().item() for t in model.values())
414
+ final_nonzero = sum((t != 0).sum().item() for t in model.values())
415
+ reduction_pct = 100 * (1 - final_magnitude / base_magnitude)
416
+
417
+ print("\n" + "=" * 80)
418
+ print(" PRUNING COMPLETE")
419
+ print("=" * 80)
420
+ print(f"\n RESULTS:")
421
+ print(f" Original magnitude: {base_magnitude:.0f}")
422
+ print(f" Final magnitude: {final_magnitude:.0f}")
423
+ print(f" Reduction: {reduction_pct:.2f}%")
424
+ print(f" Total reductions: {total_reductions}")
425
+ print(f" Original non-zero: {nonzero_params}")
426
+ print(f" Final non-zero: {final_nonzero}")
427
+ print(f" Zeros created: {nonzero_params - final_nonzero}")
428
+ print(f" Max weight: {final_max:.0f}")
429
+ print(f" Total time: {format_time(pruning_time)}")
430
+
431
+ # Save
432
+ print(f"\n SAVING to {checkpoint_path}...")
433
+ save_file(model, checkpoint_path)
434
+ print(f" Saved.")
435
+
436
+ # Final verification
437
+ print(f"\n FINAL VERIFICATION...")
438
+ from safetensors import safe_open
439
+ f = safe_open(checkpoint_path, framework='numpy')
440
+ verify_model = {name: torch.tensor(f.get_tensor(name)).float() for name in f.keys()}
441
+ verify_fitness = check_fitness(verify_model, evaluator, device)
442
+ print(f" Fitness: {verify_fitness:.6f}")
443
+
444
+ if verify_fitness >= 0.9999:
445
+ print(f" STATUS: PASS")
446
+ else:
447
+ print(f" STATUS: FAIL - Model corrupted!")
448
+
449
+ print("\n" + "=" * 80)
450
+ return model
451
+
452
+
453
# Hard cap on how many mutated model copies are materialized per evaluation
# batch (the Phase-1 loop expands every tensor to batch size on the device);
# larger values risk out-of-memory errors.
MAX_BATCH_SIZE = 80000

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Batched Weight Pruning')
    parser.add_argument('--passes', type=int, default=10,
                        help='Maximum pruning passes (default: 10)')
    parser.add_argument('--batch_size', type=int, default=80000,
                        help=f'Batch size for parallel evaluation (default: 80000, max: {MAX_BATCH_SIZE})')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device: cuda or cpu (default: cuda)')
    # Repo-relative default instead of a machine-specific absolute path
    # (the script now lives at the repository root).
    parser.add_argument('--output', type=str,
                        default='pruned.safetensors',
                        help='Output path')
    args = parser.parse_args()

    # Validate the batch size into the supported [1, MAX_BATCH_SIZE] range:
    # too-large values are clamped (as before), non-positive values are a
    # usage error reported up front rather than a silent no-op downstream.
    if args.batch_size > MAX_BATCH_SIZE:
        print(f"WARNING: batch_size {args.batch_size} exceeds maximum {MAX_BATCH_SIZE}. Clamping.")
        args.batch_size = MAX_BATCH_SIZE
    if args.batch_size < 1:
        parser.error(f"--batch_size must be >= 1 (got {args.batch_size})")
    if args.passes < 1:
        parser.error(f"--passes must be >= 1 (got {args.passes})")

    print(f"\nStarting at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")

    # prune_weights is defined earlier in this file; it runs the batched
    # magnitude-pruning passes and writes checkpoints to checkpoint_path.
    prune_weights(
        passes=args.passes,
        batch_size=args.batch_size,
        device=args.device,
        checkpoint_path=args.output
    )

    print(f"\nFinished at {time.strftime('%Y-%m-%d %H:%M:%S')}")