|
|
|
|
|
"""
|
|
|
Verification script to demonstrate all implemented functionality.
|
|
|
Run this to see layers.py and packing.py in action!
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from bitlinear import BitLinear, MultiTernaryLinear, convert_linear_to_bitlinear
|
|
|
from bitlinear.packing import (
|
|
|
pack_ternary_base3,
|
|
|
unpack_ternary_base3,
|
|
|
estimate_memory_savings,
|
|
|
)
|
|
|
|
|
|
|
|
|
def demo_bitlinear():
    """Demonstrate the BitLinear layer.

    Builds a BitLinear(512 -> 256), prints its ternary weight shape, its
    per-output scale (gamma) shape, and the set of unique weight values;
    then runs one forward pass and shows conversion from ``nn.Linear``.
    Output goes to stdout only; nothing is returned.
    """
    print("=" * 70)
    print("1. BitLinear Layer Demo")
    print("=" * 70)

    layer = BitLinear(512, 256, bias=True)
    # Fixed: these two literals had no placeholders, so the stray f-prefix
    # (lint F541) is dropped; printed bytes are unchanged.
    print("β Created BitLinear(512 β 256)")
    print(f"   - W_ternary shape: {layer.W_ternary.shape}")
    print(f"   - Gamma shape: {layer.gamma.shape}")
    print(f"   - Unique weight values: {sorted(layer.W_ternary.unique().tolist())}")

    x = torch.randn(16, 512)
    y = layer(x)
    print(f"\nβ Forward pass: {x.shape} β {y.shape}")

    # The converted layer is created only to demonstrate the API; the unused
    # local binding from the original is removed.
    linear = nn.Linear(512, 256)
    BitLinear.from_linear(linear)
    print("β Converted nn.Linear to BitLinear")
    print()
|
|
|
|
|
|
|
|
|
def demo_multi_ternary():
    """Demonstrate MultiTernaryLinear at ranks k=1,2,4 and its approximation quality.

    Prints layer shapes for each rank, then measures the reconstruction error
    of a converted dense layer and reports whether error shrinks as k grows.
    """
    print("=" * 70)
    print("2. MultiTernaryLinear Layer Demo")
    print("=" * 70)

    for rank in (1, 2, 4):
        mt_layer = MultiTernaryLinear(256, 128, k=rank, bias=True)
        print(f"β MultiTernaryLinear(256 β 128, k={rank})")
        print(f"   - W_ternary shape: {mt_layer.W_ternary.shape}")
        print(f"   - Gammas shape: {mt_layer.gammas.shape}")

    print("\nβ Approximation quality test:")
    dense = nn.Linear(128, 128)
    sample = torch.randn(8, 128)
    reference = dense(sample)

    recon_errors = []
    for rank in (1, 2, 4):
        approx = MultiTernaryLinear.from_linear(dense, k=rank)
        # L2 distance between the dense output and the ternary approximation.
        err = torch.norm(reference - approx(sample)).item()
        recon_errors.append(err)
        print(f"   - k={rank}: reconstruction error = {err:.4f}")

    monotone = recon_errors[0] > recon_errors[1] > recon_errors[2]
    print(f"   - Error decreases with k: {monotone}")
    print()
|
|
|
|
|
|
|
|
|
def demo_model_conversion():
    """Demonstrate converting every nn.Linear in a small MLP to BitLinear.

    Counts Linear layers before conversion, BitLinear layers after, and runs
    one forward pass through the converted model to confirm it still works.
    """
    print("=" * 70)
    print("3. Model Conversion Utility Demo")
    print("=" * 70)

    class SimpleModel(nn.Module):
        # Tiny 3-layer MLP used purely as a conversion target.
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(128, 256)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(256, 128)
            self.fc3 = nn.Linear(128, 10)

        def forward(self, inputs):
            hidden = self.relu(self.fc1(inputs))
            hidden = self.relu(self.fc2(hidden))
            return self.fc3(hidden)

    net = SimpleModel()

    # bools sum as 0/1, so this counts Linear submodules.
    n_dense = sum(isinstance(m, nn.Linear) for m in net.modules())
    print(f"β Original model: {n_dense} Linear layers")

    converted = convert_linear_to_bitlinear(net, inplace=False)
    n_bit = sum(isinstance(m, BitLinear) for m in converted.modules())
    print(f"β Converted model: {n_bit} BitLinear layers")

    batch = torch.randn(4, 128)
    out = converted(batch)
    print(f"β Forward pass works: {batch.shape} β {out.shape}")
    print()
|
|
|
|
|
|
|
|
|
def demo_packing():
    """Demonstrate base-3 packing: pack a ternary matrix to uint8 and round-trip it."""
    print("=" * 70)
    print("4. Base-3 Packing Demo")
    print("=" * 70)

    ternary_rows = [
        [-1, 0, 1, -1, 0],
        [1, 1, -1, 0, 1],
        [0, -1, 1, 1, -1],
    ]
    weights = torch.tensor(ternary_rows, dtype=torch.float32)

    print(f"β Original ternary weights shape: {weights.shape}")
    # float32 = 4 bytes per element.
    print(f"   - Float32 memory: {weights.numel() * 4} bytes")

    packed, saved_shape = pack_ternary_base3(weights)
    print("\nβ Packed into uint8 tensor")
    print(f"   - Packed shape: {packed.shape}")
    print(f"   - Packed memory: {packed.numel()} bytes")
    print(f"   - Compression: {weights.numel() * 4 / packed.numel():.2f}x")

    restored = unpack_ternary_base3(packed, saved_shape)
    print("\nβ Unpacked back to ternary")
    print(f"   - Unpacked shape: {restored.shape}")
    print(f"   - Perfect round-trip: {torch.allclose(weights, restored)}")
    print()
|
|
|
|
|
|
|
|
|
def demo_memory_estimation():
    """Demonstrate memory-savings estimates for transformer-sized layer configs."""
    print("=" * 70)
    print("5. Memory Savings Estimation")
    print("=" * 70)

    scenarios = (
        (768, 3072, 1, "Single Transformer FFN layer"),
        (768, 3072, 12, "BERT-base (12 layers)"),
        (1024, 4096, 24, "BERT-large (24 layers)"),
    )

    for d_in, d_out, depth, label in scenarios:
        report = estimate_memory_savings(d_in, d_out, depth)
        print(f"\nβ {label}")
        print(f"   Configuration: {d_in} β {d_out} Γ {depth} layers")
        # Byte counts converted to MB (decimal, 1e6) for display.
        print(f"   Float32 memory: {report['float32_bytes'] / 1e6:.2f} MB")
        print(f"   Packed memory: {report['packed_bytes'] / 1e6:.2f} MB")
        print(f"   Savings: {report['savings_bytes'] / 1e6:.2f} MB")
        print(f"   Compression: {report['compression_ratio']:.2f}x")
    print()
|
|
|
|
|
|
|
|
|
def main():
    """Run every demo in sequence, framed by banner output."""
    banner = "=" * 70
    print("\n" + banner)
    print("   BitLinear Implementation Verification")
    print("   All functionality implemented and working!")
    print(banner)
    print()

    # Dispatch each demo in order; each prints its own section.
    for demo in (
        demo_bitlinear,
        demo_multi_ternary,
        demo_model_conversion,
        demo_packing,
        demo_memory_estimation,
    ):
        demo()

    print(banner)
    print("  β All implementations verified!")
    print("  β Ready for C++/CUDA optimization")
    print(banner)
    print()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
main()
|
|
|
|