# NOTE(review): the three lines below are residue from a web file-viewer
# (uploader avatar caption, commit message, commit hash) and are not Python;
# kept as comments so the file parses.
# WCNegentropy's picture
# πŸš€ Refined BitTransformerLM: Organized codebase with best practices
# 2ca4b28 verified
#!/usr/bin/env python3
"""
Test script for BTLM_Extensions
===============================
Quick test to verify all extensions are working properly.
"""
import sys
import os
import torch
import torch.nn as nn
# Add paths for imports — assumes the container layout where /data holds
# BTLM_Extensions and /data/BitTransformerLM holds the core package
# (TODO confirm on other hosts; harmless no-ops if the paths don't exist).
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')
def test_imports():
    """Check that every public symbol of BTLM_Extensions can be imported.

    Returns:
        bool: True when the import succeeds, False otherwise.
    """
    print("Testing imports...")
    try:
        # Pull in the full public surface in one statement so any missing
        # symbol fails the whole check.
        from BTLM_Extensions import (
            Muon, Lion, Adafactor,
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer,
            RLEEncoder,
            extension_manager,
            get_package_info
        )
    except Exception as e:
        print(f"❌ Import failed: {e}")
        return False
    print("βœ… All imports successful")
    return True
def test_optimizers():
    """Exercise each optimizer factory with a single training step.

    Builds a tiny MLP, then for each of the muon/lion/adafactor factories
    runs one forward/backward/step/scheduler cycle.

    Returns:
        bool: True only if the factories import AND every optimizer
        completes its training step; False otherwise.
    """
    print("\nTesting optimizers...")
    # Tiny MLP standing in for a real model.
    model = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 2)
    )
    try:
        from BTLM_Extensions import (
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer
        )
    except Exception as e:
        print(f"❌ Optimizer test failed: {e}")
        return False
    optimizers_to_test = [
        ("muon", configure_muon_optimizer, {"lr": 1e-3}),
        ("lion", configure_lion_optimizer, {"lr": 1e-4}),
        ("adafactor", configure_adafactor_optimizer, {"lr": 1e-3}),
    ]
    # BUGFIX: the original returned True even when every optimizer failed;
    # per-optimizer failures were printed but never affected the result.
    all_ok = True
    for name, config_fn, kwargs in optimizers_to_test:
        try:
            optimizer, scheduler = config_fn(model, total_steps=100, **kwargs)
            # One forward/backward/step cycle proves the optimizer is usable.
            x = torch.randn(4, 10)
            y = torch.randint(0, 2, (4,))
            pred = model(x)
            loss = nn.functional.cross_entropy(pred, y)
            loss.backward()
            optimizer.step()
            if scheduler:
                scheduler.step()
            optimizer.zero_grad()
            print(f"βœ… {name.capitalize()} optimizer working")
        except Exception as e:
            print(f"❌ {name.capitalize()} optimizer failed: {e}")
            all_ok = False
    return all_ok
def test_rle_compression():
    """Round-trip every RLE compression scheme and run the benchmark helper.

    Returns:
        bool: True only if the encoder imports, every scheme reconstructs
        the input losslessly, and the benchmark runs; False otherwise.
    """
    print("\nTesting RLE compression...")
    try:
        from BTLM_Extensions import RLEEncoder, benchmark_compression_schemes
    except Exception as e:
        print(f"❌ RLE compression test failed: {e}")
        return False
    # Random bits plus two long runs so RLE has something to compress.
    test_data = torch.randint(0, 2, (50,))
    test_data[10:20] = 1
    test_data[30:40] = 0
    # BUGFIX: the original returned True even when schemes or the benchmark
    # failed; failures were printed but never affected the result.
    ok = True
    for scheme in ("basic", "delta", "adaptive"):
        try:
            encoder = RLEEncoder(scheme=scheme)
            compressed, metadata = encoder.encode(test_data)
            reconstructed = encoder.decode(compressed, metadata)
            # Round trip must be lossless: MSE of the reconstruction ~ 0.
            error = torch.mean((test_data.float() - reconstructed.float()) ** 2)
            if error.item() < 1e-6:
                print(f"βœ… RLE {scheme} scheme working (ratio: {metadata['compression_ratio']:.3f})")
            else:
                print(f"❌ RLE {scheme} scheme reconstruction error: {error.item()}")
                ok = False
        except Exception as e:
            print(f"❌ RLE {scheme} scheme failed: {e}")
            ok = False
    try:
        results = benchmark_compression_schemes(test_data)
        print(f"βœ… RLE benchmark completed ({len(results)} schemes tested)")
    except Exception as e:
        print(f"❌ RLE benchmark failed: {e}")
        ok = False
    return ok
def test_integration():
    """Smoke-test the package metadata and the extension-manager registry.

    Returns:
        bool: True when both lookups succeed, False otherwise.
    """
    print("\nTesting integration features...")
    try:
        from BTLM_Extensions import extension_manager, get_package_info
        # Package metadata must expose at least a name and a version.
        info = get_package_info()
        print(f"βœ… Package info: {info['name']} v{info['version']}")
        # The manager advertises the available optimizers and compression
        # schemes; just count them to prove the registries exist.
        n_optimizers = len(extension_manager.SUPPORTED_OPTIMIZERS)
        n_compression = len(extension_manager.SUPPORTED_COMPRESSION)
        print(f"βœ… Extension manager: {n_optimizers} optimizers, {n_compression} compression schemes")
        return True
    except Exception as e:
        print(f"❌ Integration test failed: {e}")
        return False
def test_bittransformerlm_integration():
    """Run one training step of BitTransformerLM with a configured optimizer.

    Skips (returning True) when BitTransformerLM is not installed.

    Returns:
        bool: True on success or skip, False on any other failure.
    """
    print("\nTesting BitTransformerLM integration...")
    try:
        from bit_transformer import BitTransformerLM
        from BTLM_Extensions import configure_optimizer
        # Deliberately tiny configuration to keep the smoke test fast.
        net = BitTransformerLM(
            d_model=64,
            nhead=4,
            num_layers=2,
            dim_feedforward=128,
            max_seq_len=32
        )
        opt, sched = configure_optimizer("muon", net, lr=1e-3, total_steps=10)
        bits = torch.randint(0, 2, (2, 16))
        logits, _telemetry = net(bits)
        # Next-bit prediction: shift logits and targets by one position.
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = bits[:, 1:].reshape(-1)
        loss = nn.functional.cross_entropy(pred, target)
        loss.backward()
        opt.step()
        if sched:
            sched.step()
        print(f"βœ… BitTransformerLM integration working (loss: {loss.item():.4f})")
        return True
    except ImportError:
        # Missing optional dependency is a skip, not a failure.
        print("⚠️ BitTransformerLM not available, skipping integration test")
        return True
    except Exception as e:
        print(f"❌ BitTransformerLM integration failed: {e}")
        return False
def main():
    """Run the whole suite, print a summary, and return a process exit code.

    Returns:
        int: 0 when every test passed, 1 otherwise.
    """
    print("BTLM_Extensions Test Suite")
    print("=" * 40)
    suite = (
        test_imports,
        test_optimizers,
        test_rle_compression,
        test_integration,
        test_bittransformerlm_integration,
    )
    passed = 0
    for case in suite:
        try:
            # A crashing test counts as a failure but must not stop the run.
            if case():
                passed += 1
        except Exception as e:
            print(f"❌ Test {case.__name__} crashed: {e}")
    print("\n" + "=" * 40)
    print(f"Test Results: {passed}/{len(suite)} passed")
    if passed == len(suite):
        print("πŸŽ‰ All tests passed! Extensions are working correctly.")
        return 0
    print("⚠️ Some tests failed. Check the output above.")
    return 1
if __name__ == "__main__":
    # Propagate the aggregate suite result as the process exit code
    # (0 = all passed, 1 = some failed) so CI can gate on it.
    exit_code = main()
    sys.exit(exit_code)