# NOTE(review): extraction artifact removed here — the original capture dumped
# "File size: 7,258 Bytes", a commit hash, and a run of line numbers before the
# shebang, which is not valid Python. Commented out to keep the file runnable.
#!/usr/bin/env python3
"""
Test script for BTLM_Extensions
===============================
Quick test to verify all extensions are working properly.
"""
import sys
import os
import torch
import torch.nn as nn
# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')
def test_imports():
    """Test that all public names of BTLM_Extensions can be imported.

    Returns:
        bool: True if every import succeeds, False otherwise.
    """
    print("Testing imports...")
    try:
        # Importing the names is the test itself; they are intentionally unused.
        from BTLM_Extensions import (
            Muon, Lion, Adafactor,
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer,
            RLEEncoder,
            extension_manager,
            get_package_info
        )
        print("✅ All imports successful")
        return True
    except Exception as e:
        print(f"❌ Import failed: {e}")
        return False
def test_optimizers():
    """Smoke-test each extension optimizer with one training step.

    Builds a tiny throwaway model, then for each optimizer factory runs a
    single forward/backward/step cycle and reports per-optimizer status.

    Returns:
        bool: True if the test harness ran (individual optimizer failures
        are reported but do not fail the harness); False if setup failed.
    """
    print("\nTesting optimizers...")
    # Any small nn.Module works as a stand-in model for an optimizer smoke test.
    model = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 2)
    )
    try:
        from BTLM_Extensions import (
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer
        )
        # (name, factory, extra kwargs) for each optimizer under test.
        optimizers_to_test = [
            ("muon", configure_muon_optimizer, {"lr": 1e-3}),
            ("lion", configure_lion_optimizer, {"lr": 1e-4}),
            ("adafactor", configure_adafactor_optimizer, {"lr": 1e-3}),
        ]
        for name, config_fn, kwargs in optimizers_to_test:
            try:
                optimizer, scheduler = config_fn(model, total_steps=100, **kwargs)
                # One full training step: forward, loss, backward, update.
                x = torch.randn(4, 10)
                y = torch.randint(0, 2, (4,))
                pred = model(x)
                loss = nn.functional.cross_entropy(pred, y)
                loss.backward()
                optimizer.step()
                if scheduler:
                    scheduler.step()
                # Clear grads so the next optimizer starts from a clean state.
                optimizer.zero_grad()
                print(f"✅ {name.capitalize()} optimizer working")
            except Exception as e:
                print(f"❌ {name.capitalize()} optimizer failed: {e}")
        return True
    except Exception as e:
        print(f"❌ Optimizer test failed: {e}")
        return False
def test_rle_compression():
    """Test RLE compression round-trips and the benchmark helper.

    Encodes a bit tensor containing deliberate runs with each supported
    scheme, decodes it, and checks for exact reconstruction.

    Returns:
        bool: True if the harness ran (per-scheme failures are reported
        but tolerated); False if setup failed.
    """
    print("\nTesting RLE compression...")
    try:
        from BTLM_Extensions import RLEEncoder, benchmark_compression_schemes
        # Random bits, then inject two long runs so RLE has something to compress.
        test_data = torch.randint(0, 2, (50,))
        test_data[10:20] = 1
        test_data[30:40] = 0
        schemes = ["basic", "delta", "adaptive"]
        for scheme in schemes:
            try:
                encoder = RLEEncoder(scheme=scheme)
                compressed, metadata = encoder.encode(test_data)
                reconstructed = encoder.decode(compressed, metadata)
                # MSE of 0 (up to float tolerance) means a lossless round-trip.
                error = torch.mean((test_data.float() - reconstructed.float()) ** 2)
                if error.item() < 1e-6:
                    print(f"✅ RLE {scheme} scheme working (ratio: {metadata['compression_ratio']:.3f})")
                else:
                    print(f"❌ RLE {scheme} scheme reconstruction error: {error.item()}")
            except Exception as e:
                print(f"❌ RLE {scheme} scheme failed: {e}")
        # The benchmark utility should run over all schemes without raising.
        try:
            results = benchmark_compression_schemes(test_data)
            print(f"✅ RLE benchmark completed ({len(results)} schemes tested)")
        except Exception as e:
            print(f"❌ RLE benchmark failed: {e}")
        return True
    except Exception as e:
        print(f"❌ RLE compression test failed: {e}")
        return False
def test_integration():
    """Test package metadata and the extension manager registry.

    Returns:
        bool: True on success, False on any failure.
    """
    print("\nTesting integration features...")
    try:
        from BTLM_Extensions import extension_manager, get_package_info
        # Package metadata must expose at least a name and version.
        info = get_package_info()
        print(f"✅ Package info: {info['name']} v{info['version']}")
        # Registry constants enumerate what the manager supports.
        optimizers = extension_manager.SUPPORTED_OPTIMIZERS
        compression = extension_manager.SUPPORTED_COMPRESSION
        print(f"✅ Extension manager: {len(optimizers)} optimizers, {len(compression)} compression schemes")
        return True
    except Exception as e:
        print(f"❌ Integration test failed: {e}")
        return False
def test_bittransformerlm_integration():
    """Test end-to-end integration with BitTransformerLM, if installed.

    Builds a small model, wires up an extension optimizer, and runs one
    next-bit-prediction training step.

    Returns:
        bool: True on success or when BitTransformerLM is unavailable
        (skip counts as a pass); False on a real failure.
    """
    print("\nTesting BitTransformerLM integration...")
    try:
        from bit_transformer import BitTransformerLM
        from BTLM_Extensions import configure_optimizer
        # Deliberately tiny model so the test stays fast.
        model = BitTransformerLM(
            d_model=64,
            nhead=4,
            num_layers=2,
            dim_feedforward=128,
            max_seq_len=32
        )
        optimizer, scheduler = configure_optimizer("muon", model, lr=1e-3, total_steps=10)
        # Forward pass on a random bit sequence.
        test_bits = torch.randint(0, 2, (2, 16))
        logits, telemetry = model(test_bits)
        # Next-bit prediction: logits at t predict the bit at t+1.
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = test_bits[:, 1:].reshape(-1)
        loss = nn.functional.cross_entropy(pred, target)
        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()
        print(f"✅ BitTransformerLM integration working (loss: {loss.item():.4f})")
        return True
    except ImportError:
        # Missing dependency is a skip, not a failure.
        print("⚠️ BitTransformerLM not available, skipping integration test")
        return True
    except Exception as e:
        print(f"❌ BitTransformerLM integration failed: {e}")
        return False
def main():
    """Run all tests and print a summary.

    Returns:
        int: 0 if every test passed, 1 otherwise (shell exit code).
    """
    print("BTLM_Extensions Test Suite")
    print("=" * 40)
    tests = [
        test_imports,
        test_optimizers,
        test_rle_compression,
        test_integration,
        test_bittransformerlm_integration,
    ]
    passed = 0
    total = len(tests)
    for test in tests:
        try:
            if test():
                passed += 1
        except Exception as e:
            # A crash in one test must not abort the rest of the suite.
            print(f"❌ Test {test.__name__} crashed: {e}")
    print("\n" + "=" * 40)
    print(f"Test Results: {passed}/{total} passed")
    if passed == total:
        print("🎉 All tests passed! Extensions are working correctly.")
        return 0
    else:
        print("⚠️ Some tests failed. Check the output above.")
        return 1
if __name__ == "__main__":
    # Propagate the suite's pass/fail status to the shell.
    sys.exit(main())