""" BATCHED WEIGHT PRUNING (GPU-optimized) ====================================== Phase 1: Batch eval all candidates in parallel Phase 2: Apply all successes at once, binary search if conflicts """ import torch import time import argparse from safetensors.torch import save_file from eval import BatchedFitnessEvaluator, create_population, load_model torch.manual_seed(0) def format_time(seconds): if seconds < 60: return f"{seconds:.1f}s" elif seconds < 3600: return f"{seconds/60:.1f}m" else: return f"{seconds/3600:.1f}h" def format_eta(elapsed, done, total): if done == 0: return "calculating..." rate = done / elapsed remaining = (total - done) / rate return format_time(remaining) def apply_reductions(model, reductions): """Apply a list of (name, flat_idx, shape, old_val) reductions.""" for name, flat_idx, shape, old_val in reductions: new_val = old_val - 1 if old_val > 0 else old_val + 1 flat = model[name].flatten() if flat[flat_idx].item() == old_val: flat[flat_idx] = new_val model[name] = flat.view(shape) def revert_reductions(model, reductions): """Revert a list of reductions.""" for name, flat_idx, shape, old_val in reductions: flat = model[name].flatten() new_val = old_val - 1 if old_val > 0 else old_val + 1 if flat[flat_idx].item() == new_val: flat[flat_idx] = old_val model[name] = flat.view(shape) def check_fitness(model, evaluator, device): """Check model fitness.""" torch.manual_seed(0) pop = create_population(model, 1, device) return evaluator.evaluate(pop, debug=False)[0].item() def sequential_conflict_resolution(model, evaluator, device, candidates, base_magnitude): """ Sequential fallback - tests and applies reductions one at a time. Slower but guarantees no interaction bugs. """ accepted = [] for i, (name, flat_idx, shape, old_val) in enumerate(candidates): apply_reductions(model, [(name, flat_idx, shape, old_val)]) fitness = check_fitness(model, evaluator, device) if fitness >= 0.9999: accepted.append((name, flat_idx, shape, old_val)) if (i + 1) % 50 == 0: current_mag = sum(t.abs().sum().item() for t in model.values()) reduction_pct = 100 * (1 - current_mag / base_magnitude) print(f" Sequential: {len(accepted)}/{i+1} accepted | mag={current_mag:.0f} (-{reduction_pct:.2f}%)") else: revert_reductions(model, [(name, flat_idx, shape, old_val)]) return accepted def batched_conflict_resolution(model, evaluator, device, candidates, base_magnitude): """ Batched binary search - evaluates multiple branches in parallel. Uses BFS instead of DFS to maximize batching opportunities. Verifies cumulative effect after each batch to prevent interaction bugs. """ if len(candidates) == 0: return [] # First try all at once print(f" Trying {len(candidates)} reductions at once...") apply_reductions(model, candidates) fitness = check_fitness(model, evaluator, device) if fitness >= 0.9999: current_mag = sum(t.abs().sum().item() for t in model.values()) reduction_pct = 100 * (1 - current_mag / base_magnitude) print(f" ALL {len(candidates)} OK | fitness={fitness:.6f} | " f"mag={current_mag:.0f} (-{reduction_pct:.2f}%)") return candidates # Conflict - revert and use batched BFS revert_reductions(model, candidates) print(f" CONFLICT (fitness={fitness:.6f}), starting batched resolution...") accepted = [] # Queue of (candidate_list, depth) to process pending = [(candidates, 0)] while pending: # Collect all pending groups for batch evaluation to_eval = [] for group, depth in pending: if len(group) == 0: continue elif len(group) == 1: to_eval.append((group, depth, 'single')) else: to_eval.append((group, depth, 'group')) pending = [] if not to_eval: break # Build batch: create model variants for each group batch_size = len(to_eval) print(f" Batch evaluating {batch_size} groups...") # Create population for batch eval pop = {} for name, tensor in model.items(): pop[name] = tensor.unsqueeze(0).expand(batch_size, *tensor.shape).clone().to(device) # Apply each group's reductions to its population slot for idx, (group, depth, gtype) in enumerate(to_eval): for name, flat_idx, shape, old_val in group: new_val = old_val - 1 if old_val > 0 else old_val + 1 flat_view = pop[name][idx].flatten() # Check if not already modified in base model base_val = model[name].flatten()[flat_idx].item() if base_val == old_val: flat_view[flat_idx] = new_val # Batch evaluate torch.manual_seed(0) fitnesses = evaluator.evaluate(pop, debug=False) # Process results - collect accepted groups first, then verify batch_accepted = [] ok_count = 0 conflict_count = 0 fail_count = 0 for idx, (group, depth, gtype) in enumerate(to_eval): fit = fitnesses[idx].item() indent = " " + " " * depth if fit >= 0.9999: batch_accepted.append((group, depth, indent)) ok_count += len(group) else: if len(group) == 1: name, flat_idx, shape, old_val = group[0] print(f"{indent}[1/1] FAIL {name}[{flat_idx}] | fitness={fit:.6f}") fail_count += 1 else: mid = len(group) // 2 left = group[:mid] right = group[mid:] print(f"{indent}CONFLICT ({len(group)}) fitness={fit:.6f} -> split {len(left)}+{len(right)}") pending.append((left, depth + 1)) pending.append((right, depth + 1)) conflict_count += 1 # Apply all batch-accepted reductions all_batch_reductions = [] for group, depth, indent in batch_accepted: apply_reductions(model, group) all_batch_reductions.extend(group) # Verify cumulative effect if all_batch_reductions: verify_fitness = check_fitness(model, evaluator, device) if verify_fitness >= 0.9999: # All good - commit these reductions for group, depth, indent in batch_accepted: current_mag = sum(t.abs().sum().item() for t in model.values()) reduction_pct = 100 * (1 - current_mag / base_magnitude) if len(group) == 1: name, flat_idx, shape, old_val = group[0] print(f"{indent}[1/1] OK {name}[{flat_idx}] | mag={current_mag:.0f} (-{reduction_pct:.2f}%)") else: print(f"{indent}ALL {len(group)} OK | mag={current_mag:.0f} (-{reduction_pct:.2f}%)") accepted.extend(all_batch_reductions) print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed") else: # Interaction bug detected - revert and use sequential fallback print(f" INTERACTION BUG detected (batch fitness={verify_fitness:.6f})") print(f" Reverting {len(all_batch_reductions)} reductions, falling back to sequential...") revert_reductions(model, all_batch_reductions) # Process each group sequentially seq_accepted = sequential_conflict_resolution( model, evaluator, device, all_batch_reductions, base_magnitude ) accepted.extend(seq_accepted) print(f" Sequential fallback: {len(seq_accepted)}/{len(all_batch_reductions)} accepted") else: print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed") return accepted def prune_weights( passes: int = 10, batch_size: int = 5000, device: str = 'cuda', checkpoint_path: str = "D:/8bit-threshold-computer/pruned.safetensors" ): print("=" * 80) print(" BATCHED WEIGHT PRUNING (GPU-optimized)") print("=" * 80) print(f" Device: {device}") print(f" Batch size: {batch_size}") print(f" Max passes: {passes}") print("=" * 80) # Load model print("\n[1/4] LOADING MODEL...") load_start = time.perf_counter() model = load_model() load_time = time.perf_counter() - load_start n_params = sum(t.numel() for t in model.values()) n_tensors = len(model) base_magnitude = sum(t.abs().sum().item() for t in model.values()) base_max = max(t.abs().max().item() for t in model.values()) nonzero_params = sum((t != 0).sum().item() for t in model.values()) print(f" Loaded in {load_time:.2f}s") print(f" Tensors: {n_tensors}") print(f" Parameters: {n_params}") print(f" Non-zero parameters: {nonzero_params}") print(f" Total magnitude: {base_magnitude:.0f}") print(f" Max weight: {base_max:.0f}") # Initialize evaluator print("\n[2/4] INITIALIZING EVALUATOR...") eval_start = time.perf_counter() evaluator = BatchedFitnessEvaluator(device=device) eval_time = time.perf_counter() - eval_start print(f" Initialized in {eval_time:.2f}s") # Verify initial fitness print("\n[3/4] VERIFYING BASE MODEL...") initial_fitness = check_fitness(model, evaluator, device) print(f" Fitness: {initial_fitness:.6f}") if initial_fitness < 0.9999: print(f" ERROR: Base model fitness {initial_fitness:.6f} < 0.9999") return None print(f" STATUS: PASS") # Build parameter list print("\n[4/4] BUILDING PARAMETER INDEX...") param_list = [] for name, tensor in model.items(): flat = tensor.flatten() for i in range(len(flat)): param_list.append((name, i, tensor.shape)) print(f" Indexed {len(param_list)} parameters") # Main pruning loop print("\n" + "=" * 80) print(" PRUNING STARTED") print("=" * 80) total_reductions = 0 pruning_start = time.perf_counter() for pass_num in range(passes): torch.manual_seed(0) pass_start = time.perf_counter() print(f"\n{'='*80}") print(f" PASS {pass_num + 1}/{passes}") print(f"{'='*80}") # Count candidates candidates = [] for name, idx, shape in param_list: flat = model[name].flatten() val = flat[idx].item() if val != 0: candidates.append((name, idx, shape, val)) n_candidates = len(candidates) print(f"\n Candidates: {n_candidates} non-zero weights") if n_candidates == 0: print(f" No candidates remaining. Stopping.") break # Phase 1: Batch evaluation print(f"\n PHASE 1: Batch evaluation (testing each reduction independently)") print(f" " + "-" * 60) phase1_start = time.perf_counter() successful_candidates = [] n_batches = (n_candidates + batch_size - 1) // batch_size for batch_idx, batch_start_idx in enumerate(range(0, n_candidates, batch_size)): batch = candidates[batch_start_idx:batch_start_idx + batch_size] batch_len = len(batch) batch_start_time = time.perf_counter() # Build population pop = {} for name, tensor in model.items(): pop[name] = tensor.unsqueeze(0).expand(batch_len, *tensor.shape).clone().to(device) # Apply reductions for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch): new_val = old_val - 1 if old_val > 0 else old_val + 1 flat_view = pop[name][pop_idx].flatten() flat_view[flat_idx] = new_val # Evaluate torch.manual_seed(0) if device == 'cuda': torch.cuda.synchronize() fitness = evaluator.evaluate(pop, debug=False) if device == 'cuda': torch.cuda.synchronize() # Collect successes batch_successes = 0 for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch): if fitness[pop_idx].item() >= 0.9999: successful_candidates.append((name, flat_idx, shape, old_val)) batch_successes += 1 batch_time = time.perf_counter() - batch_start_time elapsed = time.perf_counter() - phase1_start done = batch_start_idx + batch_len eta = format_eta(elapsed, done, n_candidates) throughput = batch_len / batch_time print(f" Batch {batch_idx + 1}/{n_batches}: " f"{batch_successes}/{batch_len} passed ({100*batch_successes/batch_len:.1f}%) | " f"Total OK: {len(successful_candidates)} | " f"Progress: {done}/{n_candidates} ({100*done/n_candidates:.1f}%) | " f"Speed: {throughput:.0f}/s | " f"ETA: {eta}") phase1_time = time.perf_counter() - phase1_start print(f"\n Phase 1 complete: {len(successful_candidates)}/{n_candidates} candidates " f"({100*len(successful_candidates)/n_candidates:.1f}%) in {format_time(phase1_time)}") # Phase 2: Apply with conflict resolution if len(successful_candidates) == 0: print(f"\n No reductions possible. Stopping.") break print(f"\n PHASE 2: Apply reductions with conflict resolution") print(f" " + "-" * 60) phase2_start = time.perf_counter() accepted = batched_conflict_resolution(model, evaluator, device, successful_candidates, base_magnitude) pass_reductions = len(accepted) phase2_time = time.perf_counter() - phase2_start print(f"\n Phase 2 complete: {pass_reductions} reductions applied in {format_time(phase2_time)}") # Pass summary total_reductions += pass_reductions current_magnitude = sum(t.abs().sum().item() for t in model.values()) current_nonzero = sum((t != 0).sum().item() for t in model.values()) pass_time = time.perf_counter() - pass_start reduction_pct = 100 * (1 - current_magnitude / base_magnitude) print(f"\n PASS {pass_num + 1} SUMMARY:") print(f" Reductions this pass: {pass_reductions}") print(f" Total reductions: {total_reductions}") print(f" Current magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)") print(f" Current non-zero: {current_nonzero}") print(f" Pass time: {format_time(pass_time)}") # Verify after pass print(f"\n Verifying model integrity...") fitness = check_fitness(model, evaluator, device) print(f" Fitness: {fitness:.6f} {'PASS' if fitness >= 0.9999 else 'FAIL'}") # Save checkpoint after each pass checkpoint_name = checkpoint_path.replace('.safetensors', f'_pass{pass_num + 1}.safetensors') print(f"\n Saving checkpoint: {checkpoint_name}") save_file(model, checkpoint_name) print(f" Saved. Magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)") # Also save as "latest" for easy access latest_path = checkpoint_path.replace('.safetensors', '_latest.safetensors') save_file(model, latest_path) print(f" Also saved as: {latest_path}") if pass_reductions == 0: print(f"\n No reductions achieved. Stopping early.") break # Final summary pruning_time = time.perf_counter() - pruning_start final_magnitude = sum(t.abs().sum().item() for t in model.values()) final_max = max(t.abs().max().item() for t in model.values()) final_nonzero = sum((t != 0).sum().item() for t in model.values()) reduction_pct = 100 * (1 - final_magnitude / base_magnitude) print("\n" + "=" * 80) print(" PRUNING COMPLETE") print("=" * 80) print(f"\n RESULTS:") print(f" Original magnitude: {base_magnitude:.0f}") print(f" Final magnitude: {final_magnitude:.0f}") print(f" Reduction: {reduction_pct:.2f}%") print(f" Total reductions: {total_reductions}") print(f" Original non-zero: {nonzero_params}") print(f" Final non-zero: {final_nonzero}") print(f" Zeros created: {nonzero_params - final_nonzero}") print(f" Max weight: {final_max:.0f}") print(f" Total time: {format_time(pruning_time)}") # Save print(f"\n SAVING to {checkpoint_path}...") save_file(model, checkpoint_path) print(f" Saved.") # Final verification print(f"\n FINAL VERIFICATION...") from safetensors import safe_open f = safe_open(checkpoint_path, framework='numpy') verify_model = {name: torch.tensor(f.get_tensor(name)).float() for name in f.keys()} verify_fitness = check_fitness(verify_model, evaluator, device) print(f" Fitness: {verify_fitness:.6f}") if verify_fitness >= 0.9999: print(f" STATUS: PASS") else: print(f" STATUS: FAIL - Model corrupted!") print("\n" + "=" * 80) return model MAX_BATCH_SIZE = 80000 if __name__ == "__main__": parser = argparse.ArgumentParser(description='Batched Weight Pruning') parser.add_argument('--passes', type=int, default=10, help='Maximum pruning passes (default: 10)') parser.add_argument('--batch_size', type=int, default=80000, help=f'Batch size for parallel evaluation (default: 80000, max: {MAX_BATCH_SIZE})') parser.add_argument('--device', type=str, default='cuda', help='Device: cuda or cpu (default: cuda)') parser.add_argument('--output', type=str, default='D:/8bit-threshold-computer/pruned.safetensors', help='Output path') args = parser.parse_args() if args.batch_size > MAX_BATCH_SIZE: print(f"WARNING: batch_size {args.batch_size} exceeds maximum {MAX_BATCH_SIZE}. Clamping.") args.batch_size = MAX_BATCH_SIZE print(f"\nStarting at {time.strftime('%Y-%m-%d %H:%M:%S')}\n") prune_weights( passes=args.passes, batch_size=args.batch_size, device=args.device, checkpoint_path=args.output ) print(f"\nFinished at {time.strftime('%Y-%m-%d %H:%M:%S')}")