#!/usr/bin/env python3
"""
Test script for BTLM_Extensions
===============================

Quick test to verify all extensions are working properly.

Each ``test_*`` function prints one status line per check and returns ``True``
only when every check it performs succeeds. ``main`` aggregates the results
and returns a process exit code (0 = all passed, 1 = something failed).
"""

import sys
import os

import torch
import torch.nn as nn

# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

# Maximum mean-squared error tolerated between RLE input and its round-trip.
RLE_RECONSTRUCTION_TOL = 1e-6


def reconstruction_mse(original: torch.Tensor, reconstructed: torch.Tensor) -> float:
    """Return the element-wise mean-squared error between two tensors.

    Both tensors are flattened before comparison, so a decoder that returns
    the right values in a different but same-sized shape is accepted. If the
    element counts differ, the reconstruction is unusable and ``inf`` is
    returned, instead of letting broadcasting hide the mismatch.

    Args:
        original: The tensor that was encoded.
        reconstructed: The tensor produced by decoding.

    Returns:
        The MSE as a Python float. ``0.0`` for two empty tensors, and
        ``float('inf')`` when the element counts differ.
    """
    if reconstructed.numel() != original.numel():
        return float("inf")
    if original.numel() == 0:
        # torch.mean over an empty tensor is NaN; two empty tensors match.
        return 0.0
    diff = original.float().flatten() - reconstructed.float().flatten()
    return torch.mean(diff ** 2).item()


def test_imports():
    """Test that all modules can be imported.

    Returns:
        True if every public name checked below imports cleanly, else False.
    """
    print("Testing imports...")

    try:
        from BTLM_Extensions import (
            Muon, Lion, Adafactor,
            configure_muon_optimizer, configure_lion_optimizer, configure_adafactor_optimizer,
            RLEEncoder, extension_manager, get_package_info,
        )
        print("✅ All imports successful")
        return True
    except Exception as e:
        print(f"❌ Import failed: {e}")
        return False


def test_optimizers():
    """Run one training step with each ``configure_*_optimizer`` helper.

    A fresh model is built for every optimizer. This prevents state, such as
    leftover gradients from a step that raised before ``zero_grad``, from
    leaking between cases.

    Returns:
        True if every optimizer completed a step without raising, else False.
    """
    print("\nTesting optimizers...")

    try:
        from BTLM_Extensions import (
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer,
        )
    except Exception as e:
        print(f"❌ Optimizer test failed: {e}")
        return False

    # (name, factory, extra kwargs) for each optimizer under test
    optimizers_to_test = [
        ("muon", configure_muon_optimizer, {"lr": 1e-3}),
        ("lion", configure_lion_optimizer, {"lr": 1e-4}),
        ("adafactor", configure_adafactor_optimizer, {"lr": 1e-3}),
    ]

    all_ok = True
    for name, config_fn, kwargs in optimizers_to_test:
        try:
            # Create a simple model (fresh per optimizer, see docstring)
            model = nn.Sequential(
                nn.Linear(10, 20),
                nn.ReLU(),
                nn.Linear(20, 2),
            )
            optimizer, scheduler = config_fn(model, total_steps=100, **kwargs)

            # Test a training step
            x = torch.randn(4, 10)
            y = torch.randint(0, 2, (4,))
            loss = nn.functional.cross_entropy(model(x), y)
            loss.backward()
            optimizer.step()
            if scheduler:
                scheduler.step()
            optimizer.zero_grad()

            print(f"✅ {name.capitalize()} optimizer working")
        except Exception as e:
            print(f"❌ {name.capitalize()} optimizer failed: {e}")
            # Previously ignored: a failing optimizer must fail the test.
            all_ok = False

    return all_ok


def test_rle_compression():
    """Round-trip test data through each RLE scheme, then run the benchmark.

    Returns:
        True if every scheme reconstructs the input within
        ``RLE_RECONSTRUCTION_TOL`` and the benchmark runs, else False.
    """
    print("\nTesting RLE compression...")

    try:
        from BTLM_Extensions import RLEEncoder, benchmark_compression_schemes
    except Exception as e:
        print(f"❌ RLE compression test failed: {e}")
        return False

    # Create test data with patterns. A private, seeded generator keeps the
    # data (and so the printed ratios) reproducible without touching the
    # global RNG state.
    generator = torch.Generator().manual_seed(0)
    test_data = torch.randint(0, 2, (50,), generator=generator)
    # Add some runs for better compression
    test_data[10:20] = 1
    test_data[30:40] = 0

    all_ok = True

    # Test different schemes
    for scheme in ("basic", "delta", "adaptive"):
        try:
            encoder = RLEEncoder(scheme=scheme)
            compressed, metadata = encoder.encode(test_data)
            reconstructed = encoder.decode(compressed, metadata)

            # Check reconstruction (size mismatch yields inf, i.e. failure)
            error = reconstruction_mse(test_data, reconstructed)
            if error < RLE_RECONSTRUCTION_TOL:
                print(f"✅ RLE {scheme} scheme working (ratio: {metadata['compression_ratio']:.3f})")
            else:
                print(f"❌ RLE {scheme} scheme reconstruction error: {error}")
                all_ok = False
        except Exception as e:
            print(f"❌ RLE {scheme} scheme failed: {e}")
            all_ok = False

    # Test benchmark function
    try:
        results = benchmark_compression_schemes(test_data)
        print(f"✅ RLE benchmark completed ({len(results)} schemes tested)")
    except Exception as e:
        print(f"❌ RLE benchmark failed: {e}")
        all_ok = False

    return all_ok


def test_integration():
    """Test integration features (package info and extension manager).

    Returns:
        True if the package info and extension-manager attributes read
        below are available, else False.
    """
    print("\nTesting integration features...")

    try:
        from BTLM_Extensions import extension_manager, get_package_info

        # Test package info
        info = get_package_info()
        print(f"✅ Package info: {info['name']} v{info['version']}")

        # Test extension manager
        optimizers = extension_manager.SUPPORTED_OPTIMIZERS
        compression = extension_manager.SUPPORTED_COMPRESSION
        print(f"✅ Extension manager: {len(optimizers)} optimizers, {len(compression)} compression schemes")

        return True
    except Exception as e:
        print(f"❌ Integration test failed: {e}")
        return False


def test_bittransformerlm_integration():
    """Test integration with BitTransformerLM if available.

    Only a missing ``bit_transformer`` package is treated as a skip. A missing
    ``BTLM_Extensions.configure_optimizer``, or any error during the training
    step, is reported as a failure.

    Returns:
        True on success or when BitTransformerLM is not installed, else False.
    """
    print("\nTesting BitTransformerLM integration...")

    try:
        from bit_transformer import BitTransformerLM
    except ImportError:
        print("⚠️ BitTransformerLM not available, skipping integration test")
        return True

    try:
        from BTLM_Extensions import configure_optimizer

        # Create a small BitTransformerLM model
        model = BitTransformerLM(
            d_model=64,
            nhead=4,
            num_layers=2,
            dim_feedforward=128,
            max_seq_len=32,
        )

        # Test optimizer integration
        optimizer, scheduler = configure_optimizer("muon", model, lr=1e-3, total_steps=10)

        # Simple forward pass
        test_bits = torch.randint(0, 2, (2, 16))
        logits, telemetry = model(test_bits)

        # Simple training step: predict bit t+1 from positions <= t.
        # Assumes logits are (batch, seq, 2) -- TODO confirm against model.
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = test_bits[:, 1:].reshape(-1)
        loss = nn.functional.cross_entropy(pred, target)
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        print(f"✅ BitTransformerLM integration working (loss: {loss.item():.4f})")
        return True
    except Exception as e:
        print(f"❌ BitTransformerLM integration failed: {e}")
        return False


def main():
    """Run all tests.

    Returns:
        0 if every test function returned True, otherwise 1.
    """
    print("BTLM_Extensions Test Suite")
    print("=" * 40)

    tests = [
        test_imports,
        test_optimizers,
        test_rle_compression,
        test_integration,
        test_bittransformerlm_integration,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        try:
            if test():
                passed += 1
        except Exception as e:
            # Last-resort guard so one crashing test can't abort the suite.
            print(f"❌ Test {test.__name__} crashed: {e}")

    print("\n" + "=" * 40)
    print(f"Test Results: {passed}/{total} passed")

    if passed == total:
        print("🎉 All tests passed! Extensions are working correctly.")
        return 0
    else:
        print("⚠️ Some tests failed. Check the output above.")
        return 1


if __name__ == "__main__":
    exit_code = main()
    sys.exit(exit_code)