8bit-threshold-computer / prune_weights.py
CharlesCNorton
Remove eval/ folder, move prune_weights.py to root
90f3f79
"""
BATCHED WEIGHT PRUNING (GPU-optimized)
======================================
Phase 1: Batch eval all candidates in parallel
Phase 2: Apply all successes at once, binary search if conflicts
"""
import torch
import time
import argparse
from safetensors.torch import save_file
from eval import BatchedFitnessEvaluator, create_population, load_model
# Seed torch's global RNG at import time; the pruning code below additionally
# reseeds with 0 immediately before each fitness evaluation.
torch.manual_seed(0)
def format_time(seconds):
    """Render a duration given in seconds as a short string.

    Durations under a minute are shown in seconds ("12.3s"), under an hour in
    minutes ("4.5m"), and anything else in hours ("1.2h"), each with one
    decimal place.
    """
    if seconds < 60:
        return f"{seconds:.1f}s"
    if seconds < 3600:
        return f"{seconds/60:.1f}m"
    return f"{seconds/3600:.1f}h"
def format_eta(elapsed, done, total):
    """Estimate the remaining time from the average rate observed so far.

    Args:
        elapsed: seconds spent so far.
        done: number of items completed.
        total: total number of items.

    Returns:
        ``format_time`` of the projected remaining seconds, or
        ``"calculating..."`` while no usable rate exists yet: nothing done, or
        no measurable time elapsed. The old code divided by ``elapsed`` and
        raised ZeroDivisionError in that case.
    """
    if done <= 0 or elapsed <= 0:
        return "calculating..."
    rate = done / elapsed
    remaining = (total - done) / rate
    return format_time(remaining)
def apply_reductions(model, reductions):
    """Apply each ``(name, flat_idx, shape, old_val)`` reduction to ``model`` in place.

    ``model[name]`` is flattened. The element at ``flat_idx`` is set to
    ``old_val - 1`` when ``old_val > 0`` and to ``old_val + 1`` otherwise; for
    a non-zero integer value this is one unit toward zero. The element is only
    changed if it still equals ``old_val``, so stale entries are skipped.
    ``model[name]`` is then always rebound to the flattened tensor viewed back
    as ``shape``.
    """
    for entry in reductions:
        name, flat_idx, shape, old_val = entry
        flattened = model[name].flatten()
        still_original = flattened[flat_idx].item() == old_val
        if still_original:
            flattened[flat_idx] = old_val - 1 if old_val > 0 else old_val + 1
        model[name] = flattened.view(shape)
def revert_reductions(model, reductions):
    """Undo ``(name, flat_idx, shape, old_val)`` reductions on ``model`` in place.

    The reduced value is ``old_val - 1`` for positive ``old_val``, else
    ``old_val + 1``. An element is restored to ``old_val`` only if it currently
    holds that reduced value; any other value is left untouched. ``model[name]``
    is always rebound to the flattened tensor viewed back as ``shape``.
    """
    for entry in reductions:
        name, flat_idx, shape, old_val = entry
        flattened = model[name].flatten()
        reduced_val = old_val - 1 if old_val > 0 else old_val + 1
        if flattened[flat_idx].item() == reduced_val:
            flattened[flat_idx] = old_val
        model[name] = flattened.view(shape)
def check_fitness(model, evaluator, device):
    """Score ``model`` alone and return its fitness as a Python number.

    The global torch RNG is reset to seed 0 first, so repeated checks of the
    same weights go through ``create_population``/``evaluate`` under the same
    RNG state.
    """
    torch.manual_seed(0)
    single_member = create_population(model, 1, device)
    scores = evaluator.evaluate(single_member, debug=False)
    return scores[0].item()
def sequential_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
    """
    Sequential fallback - tests and applies reductions one at a time.
    Slower but guarantees no interaction bugs.

    Each reduction is applied to ``model`` (in place) and kept only if the
    model still scores >= 0.9999; otherwise it is reverted immediately.

    Args:
        model: dict of name -> tensor, modified in place.
        evaluator: passed through to ``check_fitness``.
        device: passed through to ``check_fitness``.
        candidates: list of ``(name, flat_idx, shape, old_val)`` reductions.
        base_magnitude: original total |weight| sum, used only for progress output.

    Returns:
        The reductions that were kept, in input order.
    """
    accepted = []
    for processed, reduction in enumerate(candidates, start=1):
        apply_reductions(model, [reduction])
        if check_fitness(model, evaluator, device) >= 0.9999:
            accepted.append(reduction)
        else:
            revert_reductions(model, [reduction])
        # Report every 50 candidates processed. Previously this line lived in
        # the accept branch and was skipped whenever the 50th one was rejected.
        if processed % 50 == 0:
            current_mag = sum(t.abs().sum().item() for t in model.values())
            reduction_pct = 100 * (1 - current_mag / base_magnitude)
            print(f" Sequential: {len(accepted)}/{processed} accepted | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
    return accepted
def batched_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
    """
    Batched binary search - evaluates multiple branches in parallel.
    Uses BFS instead of DFS to maximize batching opportunities.
    Verifies cumulative effect after each batch to prevent interaction bugs.

    The full candidate list is tried first. If the combined model fails, each
    conflicting group is halved, and every group at the same BFS depth is
    scored in one ``evaluator.evaluate`` call. Groups that pass are applied
    together and the combined model is re-checked. If that combination fails,
    those reductions are reverted and retried one at a time via
    ``sequential_conflict_resolution``.

    Args:
        model: dict of name -> tensor, modified in place.
        evaluator: object whose ``evaluate(pop, debug=False)`` scores each
            population slot (assumed to return a tensor with one score per
            slot -- TODO confirm).
        device: device the population tensors are moved to.
        candidates: list of ``(name, flat_idx, shape, old_val)`` reductions.
        base_magnitude: original total |weight| sum, used only for log output.

    Returns:
        The reductions that were accepted and left applied to ``model``.

    NOTE: each BFS round builds a population with one full copy of the model
    per group, so memory grows with the number of groups at that depth.
    """
    if not candidates:
        return []
    # First try all at once
    print(f" Trying {len(candidates)} reductions at once...")
    apply_reductions(model, candidates)
    fitness = check_fitness(model, evaluator, device)
    if fitness >= 0.9999:
        current_mag = sum(t.abs().sum().item() for t in model.values())
        reduction_pct = 100 * (1 - current_mag / base_magnitude)
        print(f" ALL {len(candidates)} OK | fitness={fitness:.6f} | "
              f"mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
        return candidates
    # Conflict - revert and use batched BFS
    revert_reductions(model, candidates)
    print(f" CONFLICT (fitness={fitness:.6f}), starting batched resolution...")
    accepted = []
    # Queue of (candidate_list, depth) still to be scored.
    pending = [(candidates, 0)]
    while pending:
        to_eval = [(group, depth) for group, depth in pending if group]
        pending = []
        if not to_eval:
            break
        batch_size = len(to_eval)
        print(f" Batch evaluating {batch_size} groups...")
        # One full copy of the current model per group.
        pop = {
            name: tensor.unsqueeze(0).expand(batch_size, *tensor.shape).clone().to(device)
            for name, tensor in model.items()
        }
        for slot, (group, _depth) in enumerate(to_eval):
            for name, flat_idx, shape, old_val in group:
                # Only reduce weights the base model still holds at old_val.
                if model[name].flatten()[flat_idx].item() == old_val:
                    flat_view = pop[name][slot].flatten()  # view into the contiguous clone
                    flat_view[flat_idx] = old_val - 1 if old_val > 0 else old_val + 1
        torch.manual_seed(0)
        # Single host transfer; comparisons below use Python floats exactly
        # like the former per-slot ``.item()`` calls.
        scores = evaluator.evaluate(pop, debug=False).flatten().tolist()
        # Sort results: passing groups are collected, failing ones are split.
        batch_accepted = []
        ok_count = 0
        conflict_count = 0
        fail_count = 0
        for (group, depth), fit in zip(to_eval, scores):
            indent = " " + " " * depth
            if fit >= 0.9999:
                batch_accepted.append((group, indent))
                ok_count += len(group)
            elif len(group) == 1:
                name, flat_idx, shape, old_val = group[0]
                print(f"{indent}[1/1] FAIL {name}[{flat_idx}] | fitness={fit:.6f}")
                fail_count += 1
            else:
                mid = len(group) // 2
                left, right = group[:mid], group[mid:]
                print(f"{indent}CONFLICT ({len(group)}) fitness={fit:.6f} -> split {len(left)}+{len(right)}")
                pending.append((left, depth + 1))
                pending.append((right, depth + 1))
                conflict_count += 1
        summary = f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed"
        batch_reductions = [reduction for group, _ in batch_accepted for reduction in group]
        if not batch_reductions:
            print(summary)
            continue
        # Apply all batch-accepted reductions, then verify their combined effect.
        apply_reductions(model, batch_reductions)
        verify_fitness = check_fitness(model, evaluator, device)
        if verify_fitness >= 0.9999:
            # Everything is applied already, so one magnitude value serves
            # every log line (it used to be recomputed once per group).
            current_mag = sum(t.abs().sum().item() for t in model.values())
            reduction_pct = 100 * (1 - current_mag / base_magnitude)
            for group, indent in batch_accepted:
                if len(group) == 1:
                    name, flat_idx, shape, old_val = group[0]
                    print(f"{indent}[1/1] OK {name}[{flat_idx}] | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
                else:
                    print(f"{indent}ALL {len(group)} OK | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
            accepted.extend(batch_reductions)
            print(summary)
        else:
            # Groups passed individually but not together: revert and retry one by one.
            print(f" INTERACTION BUG detected (batch fitness={verify_fitness:.6f})")
            print(f" Reverting {len(batch_reductions)} reductions, falling back to sequential...")
            revert_reductions(model, batch_reductions)
            seq_accepted = sequential_conflict_resolution(
                model, evaluator, device, batch_reductions, base_magnitude
            )
            accepted.extend(seq_accepted)
            print(f" Sequential fallback: {len(seq_accepted)}/{len(batch_reductions)} accepted")
    return accepted
def _with_suffix(path, suffix):
    """Return ``path`` with ``suffix`` inserted just before its file extension.

    ``os.path.splitext`` only looks at the extension of the last path
    component. The previous ``str.replace('.safetensors', ...)`` had two
    problems. It also rewrote matching text in directory names. And when
    ``path`` did not contain ``.safetensors``, it returned ``path`` unchanged,
    so every per-pass checkpoint and the "latest" copy overwrote the final
    output file.
    """
    import os

    root, ext = os.path.splitext(path)
    return f"{root}{suffix}{ext}"


def _nonzero_candidates(model):
    """Return every non-zero weight as a ``(name, flat_idx, shape, value)`` tuple.

    Order follows ``model``'s iteration order, then ascending flat index.
    ``value`` is a Python scalar (from ``Tensor.tolist``). Uses
    ``Tensor.nonzero`` instead of a per-element Python loop calling
    ``flatten()`` and ``.item()`` for every parameter.
    """
    candidates = []
    for name, tensor in model.items():
        flat = tensor.flatten()
        indices = flat.nonzero().flatten()
        values = flat[indices].tolist()
        shape = tensor.shape
        candidates.extend(
            (name, idx, shape, val) for idx, val in zip(indices.tolist(), values)
        )
    return candidates


def _screen_single_reductions(model, evaluator, device, candidates, batch_size):
    """Phase 1: test every candidate reduction in isolation, ``batch_size`` at a time.

    Each population slot is a copy of ``model`` with exactly one weight
    changed (``old_val - 1`` if positive, else ``old_val + 1``). ``model``
    itself is not modified. Prints one progress line per batch.

    Returns:
        The candidates whose slot kept fitness >= 0.9999, in input order.
    """
    is_cuda = str(device).startswith('cuda')  # also covers 'cuda:N' for timing syncs
    n_candidates = len(candidates)
    n_batches = (n_candidates + batch_size - 1) // batch_size
    successful_candidates = []
    phase1_start = time.perf_counter()
    for batch_idx, batch_start_idx in enumerate(range(0, n_candidates, batch_size)):
        batch = candidates[batch_start_idx:batch_start_idx + batch_size]
        batch_len = len(batch)
        batch_start_time = time.perf_counter()
        # Build population: one full copy of the model per candidate.
        pop = {
            name: tensor.unsqueeze(0).expand(batch_len, *tensor.shape).clone().to(device)
            for name, tensor in model.items()
        }
        # Apply exactly one reduction per slot.
        for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
            flat_view = pop[name][pop_idx].flatten()  # view into the contiguous clone
            flat_view[flat_idx] = old_val - 1 if old_val > 0 else old_val + 1
        torch.manual_seed(0)
        if is_cuda:
            torch.cuda.synchronize()
        fitness = evaluator.evaluate(pop, debug=False)
        if is_cuda:
            torch.cuda.synchronize()
        # One host transfer instead of a device sync per slot; compare as Python
        # floats, exactly like the former per-slot ``.item()`` check.
        scores = fitness.flatten().tolist()
        batch_successes = 0
        for candidate, score in zip(batch, scores):
            if score >= 0.9999:
                successful_candidates.append(candidate)
                batch_successes += 1
        batch_time = time.perf_counter() - batch_start_time
        elapsed = time.perf_counter() - phase1_start
        done = batch_start_idx + batch_len
        eta = format_eta(elapsed, done, n_candidates)
        throughput = batch_len / batch_time
        print(f" Batch {batch_idx + 1}/{n_batches}: "
              f"{batch_successes}/{batch_len} passed ({100*batch_successes/batch_len:.1f}%) | "
              f"Total OK: {len(successful_candidates)} | "
              f"Progress: {done}/{n_candidates} ({100*done/n_candidates:.1f}%) | "
              f"Speed: {throughput:.0f}/s | "
              f"ETA: {eta}")
    return successful_candidates


def prune_weights(
    passes: int = 10,
    batch_size: int = 5000,
    device: str = 'cuda',
    checkpoint_path: str = "D:/8bit-threshold-computer/pruned.safetensors"
):
    """Iteratively shrink weight magnitudes while keeping fitness >= 0.9999.

    Each pass collects all non-zero weights. Phase 1 screens each one-step
    reduction independently. Phase 2 applies the survivors through
    ``batched_conflict_resolution``. A checkpoint is written after every pass
    (``<stem>_pass<N><ext>`` and ``<stem>_latest<ext>``). The final model is
    saved to ``checkpoint_path``, reloaded, and re-scored.

    Args:
        passes: maximum number of pruning passes.
        batch_size: population size used in Phase 1.
        device: device passed to the evaluator and population tensors.
        checkpoint_path: output path for the final model; per-pass files are
            derived from it.

    Returns:
        The pruned model dict, or ``None`` if the unpruned model already scores
        below 0.9999.
    """
    print("=" * 80)
    print(" BATCHED WEIGHT PRUNING (GPU-optimized)")
    print("=" * 80)
    print(f" Device: {device}")
    print(f" Batch size: {batch_size}")
    print(f" Max passes: {passes}")
    print("=" * 80)
    # Load model
    print("\n[1/4] LOADING MODEL...")
    load_start = time.perf_counter()
    model = load_model()
    load_time = time.perf_counter() - load_start
    n_params = sum(t.numel() for t in model.values())
    n_tensors = len(model)
    base_magnitude = sum(t.abs().sum().item() for t in model.values())
    base_max = max(t.abs().max().item() for t in model.values())
    nonzero_params = sum((t != 0).sum().item() for t in model.values())
    print(f" Loaded in {load_time:.2f}s")
    print(f" Tensors: {n_tensors}")
    print(f" Parameters: {n_params}")
    print(f" Non-zero parameters: {nonzero_params}")
    print(f" Total magnitude: {base_magnitude:.0f}")
    print(f" Max weight: {base_max:.0f}")
    # Initialize evaluator
    print("\n[2/4] INITIALIZING EVALUATOR...")
    eval_start = time.perf_counter()
    evaluator = BatchedFitnessEvaluator(device=device)
    eval_time = time.perf_counter() - eval_start
    print(f" Initialized in {eval_time:.2f}s")
    # Verify initial fitness
    print("\n[3/4] VERIFYING BASE MODEL...")
    initial_fitness = check_fitness(model, evaluator, device)
    print(f" Fitness: {initial_fitness:.6f}")
    if initial_fitness < 0.9999:
        print(f" ERROR: Base model fitness {initial_fitness:.6f} < 0.9999")
        return None
    print(" STATUS: PASS")
    # Candidates are now collected per pass with a vectorised scan, so no
    # per-element index is built; the count equals the total parameter count.
    print("\n[4/4] BUILDING PARAMETER INDEX...")
    print(f" Indexed {n_params} parameters")
    # Main pruning loop
    print("\n" + "=" * 80)
    print(" PRUNING STARTED")
    print("=" * 80)
    total_reductions = 0
    pruning_start = time.perf_counter()
    for pass_num in range(passes):
        torch.manual_seed(0)
        pass_start = time.perf_counter()
        print(f"\n{'='*80}")
        print(f" PASS {pass_num + 1}/{passes}")
        print(f"{'='*80}")
        candidates = _nonzero_candidates(model)
        n_candidates = len(candidates)
        print(f"\n Candidates: {n_candidates} non-zero weights")
        if n_candidates == 0:
            print(" No candidates remaining. Stopping.")
            break
        # Phase 1: Batch evaluation
        print("\n PHASE 1: Batch evaluation (testing each reduction independently)")
        print(" " + "-" * 60)
        phase1_start = time.perf_counter()
        successful_candidates = _screen_single_reductions(
            model, evaluator, device, candidates, batch_size
        )
        phase1_time = time.perf_counter() - phase1_start
        print(f"\n Phase 1 complete: {len(successful_candidates)}/{n_candidates} candidates "
              f"({100*len(successful_candidates)/n_candidates:.1f}%) in {format_time(phase1_time)}")
        # Phase 2: Apply with conflict resolution
        if not successful_candidates:
            print("\n No reductions possible. Stopping.")
            break
        print("\n PHASE 2: Apply reductions with conflict resolution")
        print(" " + "-" * 60)
        phase2_start = time.perf_counter()
        accepted = batched_conflict_resolution(model, evaluator, device, successful_candidates, base_magnitude)
        pass_reductions = len(accepted)
        phase2_time = time.perf_counter() - phase2_start
        print(f"\n Phase 2 complete: {pass_reductions} reductions applied in {format_time(phase2_time)}")
        # Pass summary
        total_reductions += pass_reductions
        current_magnitude = sum(t.abs().sum().item() for t in model.values())
        current_nonzero = sum((t != 0).sum().item() for t in model.values())
        pass_time = time.perf_counter() - pass_start
        reduction_pct = 100 * (1 - current_magnitude / base_magnitude)
        print(f"\n PASS {pass_num + 1} SUMMARY:")
        print(f" Reductions this pass: {pass_reductions}")
        print(f" Total reductions: {total_reductions}")
        print(f" Current magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")
        print(f" Current non-zero: {current_nonzero}")
        print(f" Pass time: {format_time(pass_time)}")
        # Verify after pass
        print("\n Verifying model integrity...")
        fitness = check_fitness(model, evaluator, device)
        print(f" Fitness: {fitness:.6f} {'PASS' if fitness >= 0.9999 else 'FAIL'}")
        # Save checkpoint after each pass
        checkpoint_name = _with_suffix(checkpoint_path, f'_pass{pass_num + 1}')
        print(f"\n Saving checkpoint: {checkpoint_name}")
        save_file(model, checkpoint_name)
        print(f" Saved. Magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")
        # Also save as "latest" for easy access
        latest_path = _with_suffix(checkpoint_path, '_latest')
        save_file(model, latest_path)
        print(f" Also saved as: {latest_path}")
        if pass_reductions == 0:
            print("\n No reductions achieved. Stopping early.")
            break
    # Final summary
    pruning_time = time.perf_counter() - pruning_start
    final_magnitude = sum(t.abs().sum().item() for t in model.values())
    final_max = max(t.abs().max().item() for t in model.values())
    final_nonzero = sum((t != 0).sum().item() for t in model.values())
    reduction_pct = 100 * (1 - final_magnitude / base_magnitude)
    print("\n" + "=" * 80)
    print(" PRUNING COMPLETE")
    print("=" * 80)
    print("\n RESULTS:")
    print(f" Original magnitude: {base_magnitude:.0f}")
    print(f" Final magnitude: {final_magnitude:.0f}")
    print(f" Reduction: {reduction_pct:.2f}%")
    print(f" Total reductions: {total_reductions}")
    print(f" Original non-zero: {nonzero_params}")
    print(f" Final non-zero: {final_nonzero}")
    print(f" Zeros created: {nonzero_params - final_nonzero}")
    print(f" Max weight: {final_max:.0f}")
    print(f" Total time: {format_time(pruning_time)}")
    # Save
    print(f"\n SAVING to {checkpoint_path}...")
    save_file(model, checkpoint_path)
    print(" Saved.")
    # Final verification: reload from disk and re-score.
    print("\n FINAL VERIFICATION...")
    from safetensors import safe_open
    # The context manager releases the file handle. The old code never closed
    # it; an open mapping can block later writes to the same path (notably on
    # Windows).
    with safe_open(checkpoint_path, framework='numpy') as f:
        verify_model = {name: torch.tensor(f.get_tensor(name)).float() for name in f.keys()}
    verify_fitness = check_fitness(verify_model, evaluator, device)
    print(f" Fitness: {verify_fitness:.6f}")
    if verify_fitness >= 0.9999:
        print(" STATUS: PASS")
    else:
        print(" STATUS: FAIL - Model corrupted!")
    print("\n" + "=" * 80)
    return model
# Upper bound for --batch_size. Phase 1 materialises one full model copy per
# population slot, so this caps how many copies are held at once.
MAX_BATCH_SIZE = 80000
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Batched Weight Pruning')
    parser.add_argument('--passes', type=int, default=10,
                        help='Maximum pruning passes (default: 10)')
    parser.add_argument('--batch_size', type=int, default=MAX_BATCH_SIZE,
                        help=f'Batch size for parallel evaluation (default: {MAX_BATCH_SIZE}, max: {MAX_BATCH_SIZE})')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device: cuda or cpu (default: cuda)')
    parser.add_argument('--output', type=str,
                        default='D:/8bit-threshold-computer/pruned.safetensors',
                        help='Output path')
    args = parser.parse_args()
    # Reject non-positive sizes up front. 0 would crash in range() only after
    # the model and evaluator were loaded, and negatives silently skip Phase 1.
    if args.batch_size < 1:
        parser.error(f"--batch_size must be >= 1 (got {args.batch_size})")
    if args.batch_size > MAX_BATCH_SIZE:
        print(f"WARNING: batch_size {args.batch_size} exceeds maximum {MAX_BATCH_SIZE}. Clamping.")
        args.batch_size = MAX_BATCH_SIZE
    print(f"\nStarting at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
    prune_weights(
        passes=args.passes,
        batch_size=args.batch_size,
        device=args.device,
        checkpoint_path=args.output
    )
    print(f"\nFinished at {time.strftime('%Y-%m-%d %H:%M:%S')}")