|
|
|
|
|
""" |
|
|
Test script for BTLM_Extensions |
|
|
=============================== |
|
|
|
|
|
Quick test to verify all extensions are working properly. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
# Extend the module search path so the packages imported by the tests below
# (BTLM_Extensions, bit_transformer) can be resolved.
# NOTE(review): presumably those packages live under these directories;
# confirm for your environment.
sys.path.extend(['/data', '/data/BitTransformerLM'])
|
|
|
|
|
def test_imports():
    """Test that all modules can be imported.

    Imports the public names this suite relies on from ``BTLM_Extensions``.
    The imported names are deliberately unused: a successful import is the
    entire check.

    Returns:
        bool: True if every name imported, False if the import raised.
    """
    print("Testing imports...")

    try:
        from BTLM_Extensions import (  # noqa: F401 - the import is the test
            Muon, Lion, Adafactor,
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer,
            RLEEncoder,
            extension_manager,
            get_package_info,
        )
    except Exception as e:
        # Broad on purpose: any error raised while importing the package
        # (not only ImportError) means the extensions are unusable.
        print(f"❌ Import failed: {e}")
        return False
    else:
        print("✅ All imports successful")
        return True
|
|
|
|
|
def test_optimizers():
    """Test optimizer functionality.

    Builds a small two-layer classifier and, for each optimizer configuration
    helper imported from ``BTLM_Extensions``, runs a single
    forward / backward / update step on random data.

    Returns:
        bool: True only if every optimizer completed its step. False if the
        configuration helpers could not be imported or if any individual
        optimizer raised.
    """
    print("\nTesting optimizers...")

    # Tiny classifier: 10 input features -> 20 hidden units -> 2 classes.
    model = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 2),
    )

    try:
        from BTLM_Extensions import (
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer,
        )

        # (display name, configuration helper, extra keyword arguments)
        optimizers_to_test = [
            ("muon", configure_muon_optimizer, {"lr": 1e-3}),
            ("lion", configure_lion_optimizer, {"lr": 1e-4}),
            ("adafactor", configure_adafactor_optimizer, {"lr": 1e-3}),
        ]

        all_passed = True
        for name, config_fn, kwargs in optimizers_to_test:
            try:
                # The model is shared between optimizers: drop any gradients
                # a previously failed optimizer left behind after backward().
                model.zero_grad()

                optimizer, scheduler = config_fn(model, total_steps=100, **kwargs)

                # One training step on a random batch of 4 samples.
                x = torch.randn(4, 10)
                y = torch.randint(0, 2, (4,))

                pred = model(x)
                loss = nn.functional.cross_entropy(pred, y)
                loss.backward()

                optimizer.step()
                if scheduler:
                    scheduler.step()
                optimizer.zero_grad()

                print(f"✅ {name.capitalize()} optimizer working")

            except Exception as e:
                # Keep exercising the remaining optimizers, but remember the
                # failure so the overall result is not reported as a pass.
                all_passed = False
                print(f"❌ {name.capitalize()} optimizer failed: {e}")

        return all_passed

    except Exception as e:
        print(f"❌ Optimizer test failed: {e}")
        return False
|
|
|
|
|
def test_rle_compression():
    """Test RLE compression.

    Round-trips a 50-element bit tensor through ``RLEEncoder`` for each of
    the "basic", "delta" and "adaptive" schemes, then runs
    ``benchmark_compression_schemes`` on the same data.

    Returns:
        bool: True only if every scheme reconstructed the input (mean squared
        error below 1e-6) and the benchmark ran. False if the import failed,
        a scheme raised or reconstructed incorrectly, or the benchmark raised.
    """
    print("\nTesting RLE compression...")

    try:
        from BTLM_Extensions import RLEEncoder, benchmark_compression_schemes

        # Random bits with two forced constant runs (ones at 10..19, zeros at
        # 30..39) so the input always contains long runs.
        test_data = torch.randint(0, 2, (50,))
        test_data[10:20] = 1
        test_data[30:40] = 0

        all_passed = True

        for scheme in ("basic", "delta", "adaptive"):
            try:
                encoder = RLEEncoder(scheme=scheme)
                compressed, metadata = encoder.encode(test_data)
                reconstructed = encoder.decode(compressed, metadata)

                # Mean squared error between the input and the round trip.
                error = torch.mean((test_data.float() - reconstructed.float()) ** 2)

                if error.item() < 1e-6:
                    print(f"✅ RLE {scheme} scheme working (ratio: {metadata['compression_ratio']:.3f})")
                else:
                    all_passed = False
                    print(f"❌ RLE {scheme} scheme reconstruction error: {error.item()}")

            except Exception as e:
                # Keep checking the other schemes, but record the failure.
                all_passed = False
                print(f"❌ RLE {scheme} scheme failed: {e}")

        try:
            results = benchmark_compression_schemes(test_data)
            print(f"✅ RLE benchmark completed ({len(results)} schemes tested)")
        except Exception as e:
            all_passed = False
            print(f"❌ RLE benchmark failed: {e}")

        return all_passed

    except Exception as e:
        print(f"❌ RLE compression test failed: {e}")
        return False
|
|
|
|
|
def test_integration():
    """Test integration features.

    Reads the package metadata via ``get_package_info()`` and the
    ``SUPPORTED_OPTIMIZERS`` / ``SUPPORTED_COMPRESSION`` collections exposed
    by ``extension_manager``, printing a short summary of each.

    Returns:
        bool: True if both lookups succeeded, False if anything raised
        (including a failed import or a missing key/attribute).
    """
    print("\nTesting integration features...")

    try:
        from BTLM_Extensions import extension_manager, get_package_info

        # The info mapping is expected to provide 'name' and 'version'.
        info = get_package_info()
        print(f"✅ Package info: {info['name']} v{info['version']}")

        optimizers = extension_manager.SUPPORTED_OPTIMIZERS
        compression = extension_manager.SUPPORTED_COMPRESSION
        print(f"✅ Extension manager: {len(optimizers)} optimizers, {len(compression)} compression schemes")

        return True

    except Exception as e:
        print(f"❌ Integration test failed: {e}")
        return False
|
|
|
|
|
def test_bittransformerlm_integration():
    """Test integration with BitTransformerLM if available.

    When the optional ``bit_transformer`` package can be imported, builds a
    small ``BitTransformerLM``, configures the "muon" optimizer through
    ``BTLM_Extensions.configure_optimizer`` and runs one next-bit prediction
    training step.

    Returns:
        bool: True if the step succeeded or if ``bit_transformer`` is not
        installed (the test is skipped). False if anything else failed,
        including a failure to import ``configure_optimizer``.
    """
    print("\nTesting BitTransformerLM integration...")

    # Only a missing BitTransformerLM package counts as a skip. Keeping this
    # import in its own try block prevents an ImportError raised by
    # BTLM_Extensions from being misreported as "not available".
    try:
        from bit_transformer import BitTransformerLM
    except ImportError:
        print("⚠️ BitTransformerLM not available, skipping integration test")
        return True

    try:
        from BTLM_Extensions import configure_optimizer

        # Deliberately small model so the test runs quickly.
        model = BitTransformerLM(
            d_model=64,
            nhead=4,
            num_layers=2,
            dim_feedforward=128,
            max_seq_len=32,
        )

        optimizer, scheduler = configure_optimizer("muon", model, lr=1e-3, total_steps=10)

        # Batch of 2 random bit sequences of length 16.
        test_bits = torch.randint(0, 2, (2, 16))
        # The second value returned by the model is not needed here.
        logits, _ = model(test_bits)

        # Next-bit objective: the output at position t is scored against the
        # bit at position t + 1, so drop the last output and the first bit.
        # NOTE(review): assumes logits are (batch, seq, 2) - confirm.
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = test_bits[:, 1:].reshape(-1)
        loss = nn.functional.cross_entropy(pred, target)

        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        print(f"✅ BitTransformerLM integration working (loss: {loss.item():.4f})")
        return True

    except Exception as e:
        print(f"❌ BitTransformerLM integration failed: {e}")
        return False
|
|
|
|
|
def main():
    """Run all tests.

    Calls each test function in order, counting those that return a truthy
    value. A test that raises is reported as crashed and counted as failed,
    so one broken test never stops the rest of the suite.

    Returns:
        int: 0 if every test passed, 1 otherwise (suitable as a process
        exit code).
    """
    print("BTLM_Extensions Test Suite")
    print("=" * 40)

    tests = [
        test_imports,
        test_optimizers,
        test_rle_compression,
        test_integration,
        test_bittransformerlm_integration,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        try:
            if test():
                passed += 1
        except Exception as e:
            # Top-level boundary: report the crash and move on to the next test.
            print(f"❌ Test {test.__name__} crashed: {e}")

    print("\n" + "=" * 40)
    print(f"Test Results: {passed}/{total} passed")

    if passed == total:
        print("🎉 All tests passed! Extensions are working correctly.")
        return 0

    print("⚠️ Some tests failed. Check the output above.")
    return 1
|
|
|
|
|
if __name__ == "__main__":
    # Propagate the suite result (0 = all passed, 1 = failures) to the shell.
    sys.exit(main())