|
|
""" |
|
|
BATCHED WEIGHT PRUNING (GPU-optimized) |
|
|
====================================== |
|
|
Phase 1: Batch eval all candidates in parallel |
|
|
Phase 2: Apply all successes at once, binary search if conflicts |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import time |
|
|
import argparse |
|
|
from safetensors.torch import save_file |
|
|
from eval import BatchedFitnessEvaluator, create_population, load_model |
|
|
|
|
|
# Seed the global torch RNG at import time so candidate evaluation is reproducible.
torch.manual_seed(0)
|
|
|
|
|
|
|
|
def format_time(seconds):
    """Render a duration in seconds as a short human-readable string."""
    # Pick the largest unit that keeps the displayed number small.
    for limit, divisor, suffix in ((60, 1, "s"), (3600, 60, "m")):
        if seconds < limit:
            return f"{seconds / divisor:.1f}{suffix}"
    return f"{seconds / 3600:.1f}h"
|
|
|
|
|
|
|
|
def format_eta(elapsed, done, total):
    """Estimate remaining time from progress so far.

    Args:
        elapsed: Seconds spent so far.
        done: Units of work completed.
        total: Total units of work.

    Returns:
        Human-readable remaining-time string (via format_time), or
        "calculating..." when no rate can be derived yet.
    """
    # Guard both "no work done" and "no time elapsed": the original only
    # checked done == 0 and raised ZeroDivisionError when elapsed == 0
    # with done > 0 (possible on a very fast first batch).
    if done == 0 or elapsed <= 0:
        return "calculating..."
    rate = done / elapsed
    remaining = (total - done) / rate
    return format_time(remaining)
|
|
|
|
|
|
|
|
def apply_reductions(model, reductions):
    """Apply a list of (name, flat_idx, shape, old_val) reductions.

    Each listed weight is stepped one unit toward zero, but only if it
    still holds old_val — stale entries are silently skipped, which makes
    the operation safe to re-run.
    """
    for name, flat_idx, shape, old_val in reductions:
        # One integer step toward zero from the recorded value.
        stepped = old_val - 1 if old_val > 0 else old_val + 1
        flat = model[name].flatten()
        # Skip entries whose weight no longer matches the recorded value.
        if flat[flat_idx].item() != old_val:
            continue
        flat[flat_idx] = stepped
        model[name] = flat.view(shape)
|
|
|
|
|
|
|
|
def revert_reductions(model, reductions):
    """Revert a list of reductions.

    A weight is restored to old_val only if it currently holds the reduced
    value, so reverting a reduction that was never applied is a no-op.
    """
    for name, flat_idx, shape, old_val in reductions:
        # The value apply_reductions would have written.
        reduced = old_val - 1 if old_val > 0 else old_val + 1
        flat = model[name].flatten()
        if flat[flat_idx].item() != reduced:
            continue
        flat[flat_idx] = old_val
        model[name] = flat.view(shape)
|
|
|
|
|
|
|
|
def check_fitness(model, evaluator, device):
    """Evaluate the model's fitness as a single scalar float."""
    # Reseed so every fitness check runs under identical conditions.
    torch.manual_seed(0)
    # A population of one: score just this model.
    population = create_population(model, 1, device)
    scores = evaluator.evaluate(population, debug=False)
    return scores[0].item()
|
|
|
|
|
|
|
|
def sequential_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
    """
    Sequential fallback - tests and applies reductions one at a time.
    Slower but guarantees no interaction bugs.
    """
    accepted = []
    for i, candidate in enumerate(candidates):
        # Tentatively apply this one reduction, then re-check fitness.
        apply_reductions(model, [candidate])
        if check_fitness(model, evaluator, device) < 0.9999:
            # Reduction hurt fitness: undo it and move on.
            revert_reductions(model, [candidate])
            continue
        accepted.append(candidate)
        # Progress report every 50 accepted-at candidates examined.
        if (i + 1) % 50 == 0:
            current_mag = sum(t.abs().sum().item() for t in model.values())
            reduction_pct = 100 * (1 - current_mag / base_magnitude)
            print(f" Sequential: {len(accepted)}/{i+1} accepted | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
    return accepted
|
|
|
|
|
|
|
|
def batched_conflict_resolution(model, evaluator, device, candidates, base_magnitude):
    """
    Batched binary search - evaluates multiple branches in parallel.
    Uses BFS instead of DFS to maximize batching opportunities.
    Verifies cumulative effect after each batch to prevent interaction bugs.

    Args:
        model: dict of name -> tensor; mutated in place as reductions stick.
        evaluator: has .evaluate(pop, debug=...) returning per-member fitness
            scores (see BatchedFitnessEvaluator from eval).
        device: device string used for the batched populations.
        candidates: list of (name, flat_idx, shape, old_val) reductions that
            each passed fitness individually in Phase 1.
        base_magnitude: L1 magnitude of the unpruned model, for log output.

    Returns:
        List of (name, flat_idx, shape, old_val) reductions that were
        accepted and remain applied to `model`.
    """
    if len(candidates) == 0:
        return []

    # Fast path: try every candidate at once. If their effects do not
    # interact, one fitness check settles the whole pass.
    print(f" Trying {len(candidates)} reductions at once...")
    apply_reductions(model, candidates)
    fitness = check_fitness(model, evaluator, device)

    if fitness >= 0.9999:
        current_mag = sum(t.abs().sum().item() for t in model.values())
        reduction_pct = 100 * (1 - current_mag / base_magnitude)
        print(f" ALL {len(candidates)} OK | fitness={fitness:.6f} | "
              f"mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
        return candidates

    # Candidates conflict when combined: undo everything and binary-search
    # for maximal non-conflicting subsets, breadth-first.
    revert_reductions(model, candidates)
    print(f" CONFLICT (fitness={fitness:.6f}), starting batched resolution...")

    accepted = []

    # BFS frontier of (group of reductions, split depth). Depth is only
    # used to indent log lines.
    pending = [(candidates, 0)]

    while pending:
        # Collect every group at this BFS level so one batched evaluation
        # can score them together; the 'single'/'group' tag is informational.
        to_eval = []
        for group, depth in pending:
            if len(group) == 0:
                continue
            elif len(group) == 1:
                to_eval.append((group, depth, 'single'))
            else:
                to_eval.append((group, depth, 'group'))

        pending = []

        if not to_eval:
            break

        # One population member per group under test.
        batch_size = len(to_eval)
        print(f" Batch evaluating {batch_size} groups...")

        # Build a population of identical model copies on the eval device.
        pop = {}
        for name, tensor in model.items():
            pop[name] = tensor.unsqueeze(0).expand(batch_size, *tensor.shape).clone().to(device)

        # Apply each group's reductions only to its own population member.
        for idx, (group, depth, gtype) in enumerate(to_eval):
            for name, flat_idx, shape, old_val in group:
                new_val = old_val - 1 if old_val > 0 else old_val + 1
                flat_view = pop[name][idx].flatten()

                # Only touch weights still holding their recorded value;
                # earlier accepted batches may already have changed them.
                base_val = model[name].flatten()[flat_idx].item()
                if base_val == old_val:
                    flat_view[flat_idx] = new_val

        # Fixed seed so every evaluation sees identical conditions.
        torch.manual_seed(0)
        fitnesses = evaluator.evaluate(pop, debug=False)

        # Triage each group: accept, drop (a singleton that fails can never
        # succeed), or split in half for the next BFS level.
        batch_accepted = []
        ok_count = 0
        conflict_count = 0
        fail_count = 0

        for idx, (group, depth, gtype) in enumerate(to_eval):
            fit = fitnesses[idx].item()
            indent = " " + " " * depth

            if fit >= 0.9999:
                batch_accepted.append((group, depth, indent))
                ok_count += len(group)
            else:
                if len(group) == 1:
                    name, flat_idx, shape, old_val = group[0]
                    print(f"{indent}[1/1] FAIL {name}[{flat_idx}] | fitness={fit:.6f}")
                    fail_count += 1
                else:
                    # Binary split; both halves re-enter the frontier.
                    mid = len(group) // 2
                    left = group[:mid]
                    right = group[mid:]
                    print(f"{indent}CONFLICT ({len(group)}) fitness={fit:.6f} -> split {len(left)}+{len(right)}")
                    pending.append((left, depth + 1))
                    pending.append((right, depth + 1))
                    conflict_count += 1

        # Commit every group that passed in isolation at this level.
        all_batch_reductions = []
        for group, depth, indent in batch_accepted:
            apply_reductions(model, group)
            all_batch_reductions.extend(group)

        # Groups were scored independently above, so their combined effect
        # must be re-verified before trusting them.
        if all_batch_reductions:
            verify_fitness = check_fitness(model, evaluator, device)
            if verify_fitness >= 0.9999:
                # Combined effect holds: log per-group progress.
                for group, depth, indent in batch_accepted:
                    current_mag = sum(t.abs().sum().item() for t in model.values())
                    reduction_pct = 100 * (1 - current_mag / base_magnitude)
                    if len(group) == 1:
                        name, flat_idx, shape, old_val = group[0]
                        print(f"{indent}[1/1] OK {name}[{flat_idx}] | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
                    else:
                        print(f"{indent}ALL {len(group)} OK | mag={current_mag:.0f} (-{reduction_pct:.2f}%)")
                accepted.extend(all_batch_reductions)
                print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed")
            else:
                # Groups that passed alone interact badly when combined:
                # revert them all and re-test one reduction at a time.
                print(f" INTERACTION BUG detected (batch fitness={verify_fitness:.6f})")
                print(f" Reverting {len(all_batch_reductions)} reductions, falling back to sequential...")
                revert_reductions(model, all_batch_reductions)

                seq_accepted = sequential_conflict_resolution(
                    model, evaluator, device, all_batch_reductions, base_magnitude
                )
                accepted.extend(seq_accepted)
                print(f" Sequential fallback: {len(seq_accepted)}/{len(all_batch_reductions)} accepted")
        else:
            print(f" Batch result: {ok_count} accepted, {conflict_count} split, {fail_count} failed")

    return accepted
|
|
|
|
|
|
|
|
def prune_weights(
    passes: int = 10,
    batch_size: int = 5000,
    device: str = 'cuda',
    checkpoint_path: str = "D:/8bit-threshold-computer/pruned.safetensors"
):
    """Iteratively step weights toward zero while fitness stays >= 0.9999.

    Each pass runs two phases: Phase 1 tests every non-zero weight's
    one-unit reduction independently in GPU batches; Phase 2 applies the
    survivors together via batched_conflict_resolution. A checkpoint is
    saved after every pass.

    Args:
        passes: Maximum number of pruning passes.
        batch_size: Population size per Phase-1 evaluation batch.
        device: Device string, e.g. 'cuda' or 'cpu'.
        checkpoint_path: Final .safetensors output path; per-pass and
            '_latest' checkpoint names are derived from it.

    Returns:
        The pruned model dict, or None if the base model fails the
        initial fitness verification.
    """
    print("=" * 80)
    print(" BATCHED WEIGHT PRUNING (GPU-optimized)")
    print("=" * 80)
    print(f" Device: {device}")
    print(f" Batch size: {batch_size}")
    print(f" Max passes: {passes}")
    print("=" * 80)

    # --- [1/4] Load the model and record baseline statistics. ---
    print("\n[1/4] LOADING MODEL...")
    load_start = time.perf_counter()
    model = load_model()
    load_time = time.perf_counter() - load_start

    n_params = sum(t.numel() for t in model.values())
    n_tensors = len(model)
    # L1 magnitude of the unpruned model; all reduction percentages are
    # reported relative to this.
    base_magnitude = sum(t.abs().sum().item() for t in model.values())
    base_max = max(t.abs().max().item() for t in model.values())
    nonzero_params = sum((t != 0).sum().item() for t in model.values())

    print(f" Loaded in {load_time:.2f}s")
    print(f" Tensors: {n_tensors}")
    print(f" Parameters: {n_params}")
    print(f" Non-zero parameters: {nonzero_params}")
    print(f" Total magnitude: {base_magnitude:.0f}")
    print(f" Max weight: {base_max:.0f}")

    # --- [2/4] Evaluator setup. ---
    print("\n[2/4] INITIALIZING EVALUATOR...")
    eval_start = time.perf_counter()
    evaluator = BatchedFitnessEvaluator(device=device)
    eval_time = time.perf_counter() - eval_start
    print(f" Initialized in {eval_time:.2f}s")

    # --- [3/4] The unpruned model must already pass, or pruning is moot. ---
    print("\n[3/4] VERIFYING BASE MODEL...")
    initial_fitness = check_fitness(model, evaluator, device)
    print(f" Fitness: {initial_fitness:.6f}")

    if initial_fitness < 0.9999:
        print(f" ERROR: Base model fitness {initial_fitness:.6f} < 0.9999")
        return None

    print(f" STATUS: PASS")

    # --- [4/4] Flat (name, index, shape) list of every parameter slot,
    # built once and re-scanned each pass. ---
    print("\n[4/4] BUILDING PARAMETER INDEX...")
    param_list = []
    for name, tensor in model.items():
        flat = tensor.flatten()
        for i in range(len(flat)):
            param_list.append((name, i, tensor.shape))
    print(f" Indexed {len(param_list)} parameters")

    print("\n" + "=" * 80)
    print(" PRUNING STARTED")
    print("=" * 80)

    total_reductions = 0
    pruning_start = time.perf_counter()

    for pass_num in range(passes):
        # Reseed per pass so evaluations are reproducible.
        torch.manual_seed(0)
        pass_start = time.perf_counter()

        print(f"\n{'='*80}")
        print(f" PASS {pass_num + 1}/{passes}")
        print(f"{'='*80}")

        # Candidates are all currently non-zero weights, captured with
        # their current values.
        candidates = []
        for name, idx, shape in param_list:
            flat = model[name].flatten()
            val = flat[idx].item()
            if val != 0:
                candidates.append((name, idx, shape, val))

        n_candidates = len(candidates)
        print(f"\n Candidates: {n_candidates} non-zero weights")

        if n_candidates == 0:
            print(f" No candidates remaining. Stopping.")
            break

        # PHASE 1: score each candidate's reduction independently, one
        # population member per candidate, batch_size members at a time.
        print(f"\n PHASE 1: Batch evaluation (testing each reduction independently)")
        print(f" " + "-" * 60)
        phase1_start = time.perf_counter()
        successful_candidates = []
        # Ceiling division: number of Phase-1 batches.
        n_batches = (n_candidates + batch_size - 1) // batch_size

        for batch_idx, batch_start_idx in enumerate(range(0, n_candidates, batch_size)):
            batch = candidates[batch_start_idx:batch_start_idx + batch_size]
            batch_len = len(batch)
            batch_start_time = time.perf_counter()

            # One full model copy per candidate in this batch.
            pop = {}
            for name, tensor in model.items():
                pop[name] = tensor.unsqueeze(0).expand(batch_len, *tensor.shape).clone().to(device)

            # Apply candidate pop_idx's one-unit step (toward zero) to
            # population member pop_idx only.
            for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
                new_val = old_val - 1 if old_val > 0 else old_val + 1
                flat_view = pop[name][pop_idx].flatten()
                flat_view[flat_idx] = new_val

            # Fixed seed for identical eval conditions; explicit CUDA syncs
            # so the wall-clock timing below is accurate.
            torch.manual_seed(0)
            if device == 'cuda':
                torch.cuda.synchronize()
            fitness = evaluator.evaluate(pop, debug=False)
            if device == 'cuda':
                torch.cuda.synchronize()

            # Keep candidates whose individual reduction preserved fitness.
            batch_successes = 0
            for pop_idx, (name, flat_idx, shape, old_val) in enumerate(batch):
                if fitness[pop_idx].item() >= 0.9999:
                    successful_candidates.append((name, flat_idx, shape, old_val))
                    batch_successes += 1

            # Per-batch progress / throughput / ETA report.
            batch_time = time.perf_counter() - batch_start_time
            elapsed = time.perf_counter() - phase1_start
            done = batch_start_idx + batch_len
            eta = format_eta(elapsed, done, n_candidates)
            throughput = batch_len / batch_time

            print(f" Batch {batch_idx + 1}/{n_batches}: "
                  f"{batch_successes}/{batch_len} passed ({100*batch_successes/batch_len:.1f}%) | "
                  f"Total OK: {len(successful_candidates)} | "
                  f"Progress: {done}/{n_candidates} ({100*done/n_candidates:.1f}%) | "
                  f"Speed: {throughput:.0f}/s | "
                  f"ETA: {eta}")

        phase1_time = time.perf_counter() - phase1_start
        print(f"\n Phase 1 complete: {len(successful_candidates)}/{n_candidates} candidates "
              f"({100*len(successful_candidates)/n_candidates:.1f}%) in {format_time(phase1_time)}")

        if len(successful_candidates) == 0:
            print(f"\n No reductions possible. Stopping.")
            break

        # PHASE 2: apply all Phase-1 survivors together; conflicts between
        # them are resolved by batched binary search.
        print(f"\n PHASE 2: Apply reductions with conflict resolution")
        print(f" " + "-" * 60)
        phase2_start = time.perf_counter()

        accepted = batched_conflict_resolution(model, evaluator, device, successful_candidates, base_magnitude)
        pass_reductions = len(accepted)

        phase2_time = time.perf_counter() - phase2_start
        print(f"\n Phase 2 complete: {pass_reductions} reductions applied in {format_time(phase2_time)}")

        # Pass summary statistics.
        total_reductions += pass_reductions
        current_magnitude = sum(t.abs().sum().item() for t in model.values())
        current_nonzero = sum((t != 0).sum().item() for t in model.values())
        pass_time = time.perf_counter() - pass_start
        reduction_pct = 100 * (1 - current_magnitude / base_magnitude)

        print(f"\n PASS {pass_num + 1} SUMMARY:")
        print(f" Reductions this pass: {pass_reductions}")
        print(f" Total reductions: {total_reductions}")
        print(f" Current magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")
        print(f" Current non-zero: {current_nonzero}")
        print(f" Pass time: {format_time(pass_time)}")

        # Integrity check after the pass (result is logged, not enforced).
        print(f"\n Verifying model integrity...")
        fitness = check_fitness(model, evaluator, device)
        print(f" Fitness: {fitness:.6f} {'PASS' if fitness >= 0.9999 else 'FAIL'}")

        # Per-pass checkpoint plus a rolling '_latest' copy.
        checkpoint_name = checkpoint_path.replace('.safetensors', f'_pass{pass_num + 1}.safetensors')
        print(f"\n Saving checkpoint: {checkpoint_name}")
        save_file(model, checkpoint_name)
        print(f" Saved. Magnitude: {current_magnitude:.0f} (-{reduction_pct:.2f}%)")

        latest_path = checkpoint_path.replace('.safetensors', '_latest.safetensors')
        save_file(model, latest_path)
        print(f" Also saved as: {latest_path}")

        if pass_reductions == 0:
            print(f"\n No reductions achieved. Stopping early.")
            break

    # Final summary over all passes.
    pruning_time = time.perf_counter() - pruning_start
    final_magnitude = sum(t.abs().sum().item() for t in model.values())
    final_max = max(t.abs().max().item() for t in model.values())
    final_nonzero = sum((t != 0).sum().item() for t in model.values())
    reduction_pct = 100 * (1 - final_magnitude / base_magnitude)

    print("\n" + "=" * 80)
    print(" PRUNING COMPLETE")
    print("=" * 80)
    print(f"\n RESULTS:")
    print(f" Original magnitude: {base_magnitude:.0f}")
    print(f" Final magnitude: {final_magnitude:.0f}")
    print(f" Reduction: {reduction_pct:.2f}%")
    print(f" Total reductions: {total_reductions}")
    print(f" Original non-zero: {nonzero_params}")
    print(f" Final non-zero: {final_nonzero}")
    print(f" Zeros created: {nonzero_params - final_nonzero}")
    print(f" Max weight: {final_max:.0f}")
    print(f" Total time: {format_time(pruning_time)}")

    print(f"\n SAVING to {checkpoint_path}...")
    save_file(model, checkpoint_path)
    print(f" Saved.")

    # Round-trip the saved file and re-check fitness to catch save bugs.
    print(f"\n FINAL VERIFICATION...")
    from safetensors import safe_open
    f = safe_open(checkpoint_path, framework='numpy')
    verify_model = {name: torch.tensor(f.get_tensor(name)).float() for name in f.keys()}
    verify_fitness = check_fitness(verify_model, evaluator, device)
    print(f" Fitness: {verify_fitness:.6f}")

    if verify_fitness >= 0.9999:
        print(f" STATUS: PASS")
    else:
        print(f" STATUS: FAIL - Model corrupted!")

    print("\n" + "=" * 80)
    return model
|
|
|
|
|
|
|
|
# Hard cap on --batch_size; larger requested values are clamped at startup.
MAX_BATCH_SIZE = 80000


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description='Batched Weight Pruning')
    cli.add_argument('--passes', type=int, default=10,
                     help='Maximum pruning passes (default: 10)')
    cli.add_argument('--batch_size', type=int, default=80000,
                     help=f'Batch size for parallel evaluation (default: 80000, max: {MAX_BATCH_SIZE})')
    cli.add_argument('--device', type=str, default='cuda',
                     help='Device: cuda or cpu (default: cuda)')
    cli.add_argument('--output', type=str,
                     default='D:/8bit-threshold-computer/pruned.safetensors',
                     help='Output path')
    opts = cli.parse_args()

    # Clamp oversized batch sizes instead of failing outright.
    if opts.batch_size > MAX_BATCH_SIZE:
        print(f"WARNING: batch_size {opts.batch_size} exceeds maximum {MAX_BATCH_SIZE}. Clamping.")
        opts.batch_size = MAX_BATCH_SIZE

    print(f"\nStarting at {time.strftime('%Y-%m-%d %H:%M:%S')}\n")

    prune_weights(
        passes=opts.passes,
        batch_size=opts.batch_size,
        device=opts.device,
        checkpoint_path=opts.output,
    )

    print(f"\nFinished at {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
|