| """
|
| BATCHED WEIGHT PRUNING (GPU-optimized)
|
| ======================================
|
| Phase 1: Batch eval all candidates in parallel
|
| Phase 2: Apply all successes at once, binary search if conflicts
|
| """
|
|
|
| import torch
|
| import time
|
| import argparse
|
| from safetensors.torch import save_file
|
| from iron_eval import BatchedFitnessEvaluator, create_population, load_model
|
|
|
| torch.manual_seed(0)
|
|
|
|
|
def format_time(seconds):
    """Render a duration in seconds as a short human-readable string."""
    # Cascade from the largest unit down; boundaries (60, 3600) round
    # the same way as a `< 60 / < 3600 / else` chain.
    if seconds >= 3600:
        return f"{seconds/3600:.1f}h"
    if seconds >= 60:
        return f"{seconds/60:.1f}m"
    return f"{seconds:.1f}s"
|
|
|
|
|
def format_eta(elapsed, done, total):
    """Estimate the remaining time from progress so far.

    Args:
        elapsed: Seconds spent so far.
        done: Number of items completed.
        total: Total number of items.

    Returns:
        A human-readable ETA string (via ``format_time``), or
        ``"calculating..."`` when no rate can be derived yet.
    """
    # Guard both zero progress AND zero elapsed time: the original only
    # checked `done == 0`, so an item completing within the timer's first
    # tick raised ZeroDivisionError on `done / elapsed`.
    if done == 0 or elapsed <= 0:
        return "calculating..."
    rate = done / elapsed
    remaining = (total - done) / rate
    return format_time(remaining)
|
|
|
|
|
def apply_reductions(model, reductions):
    """Apply a list of (name, flat_idx, shape, old_val) reductions."""
    for reduction in reductions:
        name, flat_idx, shape, old_val = reduction
        weights = model[name].flatten()
        # Skip stale entries: only mutate if the weight still holds the
        # value recorded when the candidate was collected.
        if weights[flat_idx].item() != old_val:
            continue
        # Step the value one unit toward zero (handles both signs).
        weights[flat_idx] = old_val - 1 if old_val > 0 else old_val + 1
        model[name] = weights.view(shape)
|
|
|
|
|
def revert_reductions(model, reductions):
    """Revert a list of reductions."""
    for reduction in reductions:
        name, flat_idx, shape, old_val = reduction
        # Reconstruct the value the matching apply would have written...
        reduced = old_val - 1 if old_val > 0 else old_val + 1
        weights = model[name].flatten()
        # ...and restore old_val only if that reduction actually landed.
        if weights[flat_idx].item() != reduced:
            continue
        weights[flat_idx] = old_val
        model[name] = weights.view(shape)
|
|
|
|
|
def check_fitness(model, evaluator, device):
    """Check model fitness."""
    # Reseed so each evaluation is deterministic and comparable.
    torch.manual_seed(0)
    population = create_population(model, 1, device)
    scores = evaluator.evaluate(population, debug=False)
    # Population of one: the single score is the model's fitness.
    return scores[0].item()
|
|
|
|
|
def sequential_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
    """
    Sequential fallback - tests and applies reductions one at a time.
    Slower but guarantees no interaction bugs.
    """
    accepted = []
    for i, candidate in enumerate(candidates):
        # Apply a single reduction, keep it only if fitness holds.
        apply_reductions(model, [candidate])
        if check_fitness(model, evaluator, device) >= 0.9999:
            accepted.append(candidate)
            # Progress report every 50 tested candidates.
            if (i + 1) % 50 == 0:
                current_mag = sum(t.abs().sum().item() for t in model.values())
                reduction_pct = 100 * (1 - current_mag / base_magnitude)
                print(f" Sequential: {len(accepted)}/{i+1} accepted | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
        else:
            revert_reductions(model, [candidate])
    return accepted
|
|
|
|
|
def batched_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
    """
    Batched binary search - evaluates multiple branches in parallel.
    Uses BFS instead of DFS to maximize batching opportunities.
    Verifies cumulative effect after each batch to prevent interaction bugs.
    """
    if len(candidates) == 0:
        return []

    # Fast path: apply every candidate at once and see if fitness survives.
    print(f" Trying {len(candidates)} reductions at once...")
    apply_reductions(model, candidates)
    fitness = check_fitness(model, evaluator, device)

    if fitness >= 0.9999:
        current_mag = sum(t.abs().sum().item() for t in model.values())
        reduction_pct = 100 * (1 - current_mag / base_magnitude)
        print(f" ALL {len(candidates)} OK | fitness={fitness:.6f} | "
              f"mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
        return candidates

    # Fast path failed: undo everything and binary-search the conflicts.
    revert_reductions(model, candidates)
    print(f" CONFLICT (fitness={fitness:.6f}), starting batched resolution...")

    accepted = []

    # BFS frontier of (candidate_group, split_depth) pairs.
    pending = [(candidates, 0)]

    while pending:

        # Classify the current frontier; drop empty groups. The 'single' /
        # 'group' tag is informational (both paths are evaluated the same way).
        to_eval = []
        for group, depth in pending:
            if len(group) == 0:
                continue
            elif len(group) == 1:
                to_eval.append((group, depth, 'single'))
            else:
                to_eval.append((group, depth, 'group'))

        # Failed groups re-populate `pending` as split halves below.
        pending = []

        if not to_eval:
            break

        batch_size = len(to_eval)
        print(f" Batch evaluating {batch_size} groups...")

        # One population slot per group: copy the current model once per slot.
        # NOTE(review): clones the full model batch_size times — assumes it
        # fits in device memory.
        pop = {}
        for name, tensor in model.items():
            pop[name] = tensor.unsqueeze(0).expand(batch_size, *tensor.shape).clone().to(device)

        # Apply each group's reductions to its own population slot only.
        for idx, (group, depth, gtype) in enumerate(to_eval):
            for name, flat_idx, shape, old_val in group:
                new_val = old_val - 1 if old_val > 0 else old_val + 1
                flat_view = pop[name][idx].flatten()

                # Skip candidates whose recorded old_val is stale in the
                # live model (mirrors the guard in apply_reductions).
                base_val = model[name].flatten()[flat_idx].item()
                if base_val == old_val:
                    flat_view[flat_idx] = new_val

        # Deterministic evaluation of all slots in one call.
        torch.manual_seed(0)
        fitnesses = evaluator.evaluate(pop, debug=False)

        # Triage each group: accept, fail (singleton), or split and requeue.
        batch_accepted = []
        ok_count = 0
        conflict_count = 0
        fail_count = 0

        for idx, (group, depth, gtype) in enumerate(to_eval):
            fit = fitnesses[idx].item()
            # Indentation grows with split depth for readable logs.
            indent = " " + " " * depth

            if fit >= 0.9999:
                batch_accepted.append((group, depth, indent))
                ok_count += len(group)
            else:
                if len(group) == 1:
                    # A single reduction that fails alone is rejected for good.
                    name, flat_idx, shape, old_val = group[0]
                    print(f"{indent}[1/1] FAIL {name}[{flat_idx}] | fitness={fit:.6f}")
                    fail_count += 1
                else:
                    # Binary split: both halves go back on the frontier.
                    mid = len(group) // 2
                    left = group[:mid]
                    right = group[mid:]
                    print(f"{indent}CONFLICT ({len(group)}) fitness={fit:.6f} -> split {len(left)}+{len(right)}")
                    pending.append((left, depth + 1))
                    pending.append((right, depth + 1))
                    conflict_count += 1

        # Commit all groups that passed individually...
        all_batch_reductions = []
        for group, depth, indent in batch_accepted:
            apply_reductions(model, group)
            all_batch_reductions.extend(group)

        # ...then verify their CUMULATIVE effect: groups that pass alone can
        # still interact badly when combined.
        if all_batch_reductions:
            verify_fitness = check_fitness(model, evaluator, device)
            if verify_fitness >= 0.9999:

                # Cumulative effect is clean: log each accepted group.
                for group, depth, indent in batch_accepted:
                    current_mag = sum(t.abs().sum().item() for t in model.values())
                    reduction_pct = 100 * (1 - current_mag / base_magnitude)
                    if len(group) == 1:
                        name, flat_idx, shape, old_val = group[0]
                        print(f"{indent}[1/1] OK {name}[{flat_idx}] | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
                    else:
                        print(f"{indent}ALL {len(group)} OK | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
                accepted.extend(all_batch_reductions)
                print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed")
            else:

                # Interaction detected: roll back the whole batch and retry
                # the same reductions one at a time (slow but safe).
                print(f" INTERACTION BUG detected (batch fitness={verify_fitness:.6f})")
                print(f" Reverting {len(all_batch_reductions)} reductions, falling back to sequential...")
                revert_reductions(model, all_batch_reductions)

                seq_accepted = sequential_conflict_resolution(
                    model, evaluator, device, all_batch_reductions, base_magnitude
                )
                accepted.extend(seq_accepted)
                print(f" Sequential fallback: {len(seq_accepted)}/{len(all_batch_reductions)} accepted")
        else:
            print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed")

    return accepted
|
|
|
|
|
def prune_weights(
    passes: int = 10,
    batch_size: int = 5000,
    device: str = 'cuda',
    checkpoint_path: str = "D:/8bit-threshold-computer/pruned.safetensors"
):
    """Run batched magnitude pruning over the model.

    Each pass: (1) test every non-zero weight's one-step-toward-zero
    reduction independently in GPU batches, (2) apply the survivors with
    conflict resolution, then verify, checkpoint, and repeat.

    Args:
        passes: Maximum number of pruning passes.
        batch_size: Candidates evaluated per GPU batch in phase 1.
            NOTE(review): default here is 5000 but the CLI default is 80000.
        device: 'cuda' or 'cpu'.
        checkpoint_path: Base .safetensors path for checkpoints and output.

    Returns:
        The pruned model dict, or None if the base model fails verification.
    """
    print("=" * 80)
    print(" BATCHED WEIGHT PRUNING (GPU-optimized)")
    print("=" * 80)
    print(f" Device: {device}")
    print(f" Batch size: {batch_size}")
    print(f" Max passes: {passes}")
    print("=" * 80)

    print("\n[1/4] LOADING MODEL...")
    load_start = time.perf_counter()
    model = load_model()
    load_time = time.perf_counter() - load_start

    # Baseline statistics; base_magnitude anchors all reduction-% reporting.
    n_params = sum(t.numel() for t in model.values())
    n_tensors = len(model)
    base_magnitude = sum(t.abs().sum().item() for t in model.values())
    base_max = max(t.abs().max().item() for t in model.values())
    nonzero_params = sum((t != 0).sum().item() for t in model.values())

    print(f" Loaded in {load_time:.2f}s")
    print(f" Tensors: {n_tensors}")
    print(f" Parameters: {n_params}")
    print(f" Non-zero parameters: {nonzero_params}")
    print(f" Total magnitude: {base_magnitude:.0f}")
    print(f" Max weight: {base_max:.0f}")

    print("\n[2/4] INITIALIZING EVALUATOR...")
    eval_start = time.perf_counter()
    evaluator = BatchedFitnessEvaluator(device=device)
    eval_time = time.perf_counter() - eval_start
    print(f" Initialized in {eval_time:.2f}s")

    # Refuse to prune a model that doesn't already pass — there would be
    # no valid fitness baseline to preserve.
    print("\n[3/4] VERIFYING BASE MODEL...")
    initial_fitness = check_fitness(model, evaluator, device)
    print(f" Fitness: {initial_fitness:.6f}")

    if initial_fitness < 0.9999:
        print(f" ERROR: Base model fitness {initial_fitness:.6f} < 0.9999")
        return None

    print(f" STATUS: PASS")

    # Static index of every (tensor, flat position); values are re-read each
    # pass so entries that were pruned to zero drop out of the candidate set.
    print("\n[4/4] BUILDING PARAMETER INDEX...")
    param_list = []
    for name, tensor in model.items():
        flat = tensor.flatten()
        for i in range(len(flat)):
            param_list.append((name, i, tensor.shape))
    print(f" Indexed {len(param_list)} parameters")

    print("\n" + "=" * 80)
    print(" PRUNING STARTED")
    print("=" * 80)

    total_reductions = 0
    pruning_start = time.perf_counter()

    for pass_num in range(passes):
        torch.manual_seed(0)
        pass_start = time.perf_counter()

        print(f"\n{'='*80}")
        print(f" PASS {pass_num + 1}/{passes}")
        print(f"{'='*80}")

        # Collect this pass's candidates: every currently non-zero weight.
        candidates = []
        for name, idx, shape in param_list:
            flat = model[name].flatten()
            val = flat[idx].item()
            if val != 0:
                candidates.append((name, idx, shape, val))

        n_candidates = len(candidates)
        print(f"\n Candidates: {n_candidates} non-zero weights")

        if n_candidates == 0:
            print(f" No candidates remaining. Stopping.")
            break

        # PHASE 1: test each reduction in isolation, batch_size at a time.
        print(f"\n PHASE 1: Batch evaluation (testing each reduction independently)")
        print(f" " + "-" * 60)
        phase1_start = time.perf_counter()
        successful_candidates = []
        n_batches = (n_candidates + batch_size - 1) // batch_size

        for batch_idx, batch_start_idx in enumerate(range(0, n_candidates, batch_size)):
            batch = candidates[batch_start_idx:batch_start_idx + batch_size]
            batch_len = len(batch)
            batch_start_time = time.perf_counter()

            # One model copy per candidate in the batch.
            pop = {}
            for name, tensor in model.items():
                pop[name] = tensor.unsqueeze(0).expand(batch_len, *tensor.shape).clone().to(device)

            # Each slot gets exactly one weight stepped toward zero.
            for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
                new_val = old_val - 1 if old_val > 0 else old_val + 1
                flat_view = pop[name][pop_idx].flatten()
                flat_view[flat_idx] = new_val

            # Synchronize around evaluate() so batch_time measures GPU work.
            torch.manual_seed(0)
            if device == 'cuda':
                torch.cuda.synchronize()
            fitness = evaluator.evaluate(pop, debug=False)
            if device == 'cuda':
                torch.cuda.synchronize()

            batch_successes = 0
            for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
                if fitness[pop_idx].item() >= 0.9999:
                    successful_candidates.append((name, flat_idx, shape, old_val))
                    batch_successes += 1

            batch_time = time.perf_counter() - batch_start_time
            elapsed = time.perf_counter() - phase1_start
            done = batch_start_idx + batch_len
            eta = format_eta(elapsed, done, n_candidates)
            throughput = batch_len / batch_time

            print(f" Batch {batch_idx + 1}/{n_batches}: "
                  f"{batch_successes}/{batch_len} passed ({100*batch_successes/batch_len:.1f}%) | "
                  f"Total OK: {len(successful_candidates)} | "
                  f"Progress: {done}/{n_candidates} ({100*done/n_candidates:.1f}%) | "
                  f"Speed: {throughput:.0f}/s | "
                  f"ETA: {eta}")

        phase1_time = time.perf_counter() - phase1_start
        print(f"\n Phase 1 complete: {len(successful_candidates)}/{n_candidates} candidates "
              f"({100*len(successful_candidates)/n_candidates:.1f}%) in {format_time(phase1_time)}")

        if len(successful_candidates) == 0:
            print(f"\n No reductions possible. Stopping.")
            break

        # PHASE 2: apply the individually-safe reductions together, letting
        # the conflict resolver sort out interactions.
        print(f"\n PHASE 2: Apply reductions with conflict resolution")
        print(f" " + "-" * 60)
        phase2_start = time.perf_counter()

        accepted = batched_conflict_resolution(model, evaluator, device, successful_candidates, base_magnitude)
        pass_reductions = len(accepted)

        phase2_time = time.perf_counter() - phase2_start
        print(f"\n Phase 2 complete: {pass_reductions} reductions applied in {format_time(phase2_time)}")

        total_reductions += pass_reductions
        current_magnitude = sum(t.abs().sum().item() for t in model.values())
        current_nonzero = sum((t != 0).sum().item() for t in model.values())
        pass_time = time.perf_counter() - pass_start
        reduction_pct = 100 * (1 - current_magnitude / base_magnitude)

        print(f"\n PASS {pass_num + 1} SUMMARY:")
        print(f" Reductions this pass: {pass_reductions}")
        print(f" Total reductions: {total_reductions}")
        print(f" Current magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")
        print(f" Current non-zero: {current_nonzero}")
        print(f" Pass time: {format_time(pass_time)}")

        # Integrity check is logged but does not stop the run.
        print(f"\n Verifying model integrity...")
        fitness = check_fitness(model, evaluator, device)
        print(f" Fitness: {fitness:.6f} {'PASS' if fitness >= 0.9999 else 'FAIL'}")

        # Per-pass checkpoint plus a rolling "_latest" copy.
        checkpoint_name = checkpoint_path.replace('.safetensors', f'_pass{pass_num + 1}.safetensors')
        print(f"\n Saving checkpoint: {checkpoint_name}")
        save_file(model, checkpoint_name)
        print(f" Saved. Magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")

        latest_path = checkpoint_path.replace('.safetensors', '_latest.safetensors')
        save_file(model, latest_path)
        print(f" Also saved as: {latest_path}")

        if pass_reductions == 0:
            print(f"\n No reductions achieved. Stopping early.")
            break

    pruning_time = time.perf_counter() - pruning_start
    final_magnitude = sum(t.abs().sum().item() for t in model.values())
    final_max = max(t.abs().max().item() for t in model.values())
    final_nonzero = sum((t != 0).sum().item() for t in model.values())
    reduction_pct = 100 * (1 - final_magnitude / base_magnitude)

    print("\n" + "=" * 80)
    print(" PRUNING COMPLETE")
    print("=" * 80)
    print(f"\n RESULTS:")
    print(f" Original magnitude: {base_magnitude:.0f}")
    print(f" Final magnitude: {final_magnitude:.0f}")
    print(f" Reduction: {reduction_pct:.2f}%")
    print(f" Total reductions: {total_reductions}")
    print(f" Original non-zero: {nonzero_params}")
    print(f" Final non-zero: {final_nonzero}")
    print(f" Zeros created: {nonzero_params - final_nonzero}")
    print(f" Max weight: {final_max:.0f}")
    print(f" Total time: {format_time(pruning_time)}")

    print(f"\n SAVING to {checkpoint_path}...")
    save_file(model, checkpoint_path)
    print(f" Saved.")

    # Round-trip verification: reload what was written to disk and re-check.
    print(f"\n FINAL VERIFICATION...")
    from safetensors import safe_open
    # FIX: use the context manager so the file handle is closed (the
    # original left the safe_open handle open for the process lifetime).
    with safe_open(checkpoint_path, framework='numpy') as f:
        verify_model = {name: torch.tensor(f.get_tensor(name)).float() for name in f.keys()}
    verify_fitness = check_fitness(verify_model, evaluator, device)
    print(f" Fitness: {verify_fitness:.6f}")

    if verify_fitness >= 0.9999:
        print(f" STATUS: PASS")
    else:
        print(f" STATUS: FAIL - Model corrupted!")

    print("\n" + "=" * 80)
    return model
|
|
|
|
|
# Hard ceiling on phase-1 batch size (CLI values above this are clamped).
MAX_BATCH_SIZE = 80000

if __name__ == "__main__":
    cli = argparse.ArgumentParser(description='Batched Weight Pruning')
    cli.add_argument('--passes', type=int, default=10,
                     help='Maximum pruning passes (default: 10)')
    cli.add_argument('--batch_size', type=int, default=80000,
                     help=f'Batch size for parallel evaluation (default: 80000, max: {MAX_BATCH_SIZE})')
    cli.add_argument('--device', type=str, default='cuda',
                     help='Device: cuda or cpu (default: cuda)')
    cli.add_argument('--output', type=str,
                     default='D:/8bit-threshold-computer/pruned.safetensors',
                     help='Output path')
    args = cli.parse_args()

    # Clamp oversized batch sizes instead of erroring out.
    if args.batch_size > MAX_BATCH_SIZE:
        print(f"WARNING: batch_size {args.batch_size} exceeds maximum {MAX_BATCH_SIZE}. Clamping.")
        args.batch_size = MAX_BATCH_SIZE

    print(f"\nStarting at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")

    prune_weights(passes=args.passes,
                  batch_size=args.batch_size,
                  device=args.device,
                  checkpoint_path=args.output)

    print(f"\nFinished at {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
|