File size: 7,258 Bytes
2ca4b28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#!/usr/bin/env python3
"""
Test script for BTLM_Extensions
===============================

Quick test to verify all extensions are working properly.
"""

import sys
import os
import torch
import torch.nn as nn

# Make the extensions package and the BitTransformerLM checkout importable.
# NOTE(review): these absolute paths assume a repository layout under /data;
# adjust them when running the suite from a different location.
sys.path.extend(["/data", "/data/BitTransformerLM"])

def test_imports():
    """Check that the public BTLM_Extensions API can be imported.

    Returns:
        True when every listed name imports cleanly, False otherwise (the
        import error is printed).
    """
    print("Testing imports...")

    try:
        # Imported only to prove the names exist; they are intentionally unused.
        from BTLM_Extensions import (  # noqa: F401
            Muon, Lion, Adafactor,
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer,
            RLEEncoder,
            extension_manager,
            get_package_info,
        )
        print("βœ… All imports successful")
        ok = True
    except Exception as exc:
        print(f"❌ Import failed: {exc}")
        ok = False
    return ok

def test_optimizers():
    """Run one training step with each extension optimizer.

    A tiny two-layer MLP is built once and trained for a single step with
    every optimizer returned by the package's ``configure_*_optimizer``
    helpers. Each helper is called as ``config_fn(model, total_steps=100,
    lr=...)`` and must return an ``(optimizer, scheduler)`` pair; the
    scheduler is stepped only when it is truthy.

    Returns:
        True only if the helpers import and every optimizer completes its
        step. False if the import fails or any optimizer raises (each
        failure is printed).
    """
    print("\nTesting optimizers...")

    # Create a simple model
    model = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 2),
    )

    try:
        from BTLM_Extensions import (
            configure_muon_optimizer,
            configure_lion_optimizer,
            configure_adafactor_optimizer,
        )
    except Exception as e:
        print(f"❌ Optimizer test failed: {e}")
        return False

    optimizers_to_test = [
        ("muon", configure_muon_optimizer, {"lr": 1e-3}),
        ("lion", configure_lion_optimizer, {"lr": 1e-4}),
        ("adafactor", configure_adafactor_optimizer, {"lr": 1e-3}),
    ]

    failed = []
    for name, config_fn, kwargs in optimizers_to_test:
        # Drop gradients a previous optimizer may have left behind if it
        # raised between backward() and zero_grad(); otherwise they would
        # be accumulated into this optimizer's step.
        model.zero_grad()
        try:
            optimizer, scheduler = config_fn(model, total_steps=100, **kwargs)

            # Test a training step
            x = torch.randn(4, 10)
            y = torch.randint(0, 2, (4,))

            pred = model(x)
            loss = nn.functional.cross_entropy(pred, y)
            loss.backward()

            optimizer.step()
            if scheduler:
                scheduler.step()
            optimizer.zero_grad()

            print(f"βœ… {name.capitalize()} optimizer working")

        except Exception as e:
            print(f"❌ {name.capitalize()} optimizer failed: {e}")
            failed.append(name)

    # Previously this always returned True, hiding per-optimizer failures
    # from main()'s pass count.
    return not failed

def test_rle_compression():
    """Round-trip a bit tensor through each RLE scheme and run the benchmark.

    For every scheme, ``RLEEncoder(scheme=...).encode`` must return a
    ``(compressed, metadata)`` pair with a ``'compression_ratio'`` entry.
    ``decode(compressed, metadata)`` must reproduce the input with the same
    shape and (near-)identical values.

    Returns:
        True only if the import succeeds, every scheme round-trips, and the
        benchmark runs. False otherwise (each failure is printed).
    """
    print("\nTesting RLE compression...")

    try:
        from BTLM_Extensions import RLEEncoder, benchmark_compression_schemes
    except Exception as e:
        print(f"❌ RLE compression test failed: {e}")
        return False

    # Random bits with two forced runs so run-length encoding has runs to
    # compress.
    test_data = torch.randint(0, 2, (50,))
    test_data[10:20] = 1
    test_data[30:40] = 0

    failed = []
    for scheme in ["basic", "delta", "adaptive"]:
        try:
            encoder = RLEEncoder(scheme=scheme)
            compressed, metadata = encoder.encode(test_data)
            reconstructed = encoder.decode(compressed, metadata)

            # Require an exact shape match: the MSE below broadcasts, so a
            # wrongly shaped result (e.g. (1, 50)) could otherwise pass.
            if reconstructed.shape != test_data.shape:
                print(
                    f"❌ RLE {scheme} scheme shape mismatch: expected "
                    f"{tuple(test_data.shape)}, got {tuple(reconstructed.shape)}"
                )
                failed.append(scheme)
                continue

            error = torch.mean((test_data.float() - reconstructed.float()) ** 2)

            if error.item() < 1e-6:
                print(f"βœ… RLE {scheme} scheme working (ratio: {metadata['compression_ratio']:.3f})")
            else:
                print(f"❌ RLE {scheme} scheme reconstruction error: {error.item()}")
                failed.append(scheme)

        except Exception as e:
            print(f"❌ RLE {scheme} scheme failed: {e}")
            failed.append(scheme)

    # Test benchmark function
    try:
        results = benchmark_compression_schemes(test_data)
        print(f"βœ… RLE benchmark completed ({len(results)} schemes tested)")
    except Exception as e:
        print(f"❌ RLE benchmark failed: {e}")
        failed.append("benchmark")

    # Previously this always returned True, so broken schemes never failed
    # the suite.
    return not failed

def test_integration():
    """Query the package metadata and the extension-manager registries.

    Returns:
        True when ``get_package_info()`` yields ``name``/``version`` and the
        manager exposes sized ``SUPPORTED_OPTIMIZERS`` and
        ``SUPPORTED_COMPRESSION`` collections; False otherwise (the error is
        printed).
    """
    print("\nTesting integration features...")

    try:
        from BTLM_Extensions import extension_manager, get_package_info

        package = get_package_info()
        print(f"βœ… Package info: {package['name']} v{package['version']}")

        opts = extension_manager.SUPPORTED_OPTIMIZERS
        comps = extension_manager.SUPPORTED_COMPRESSION
        print(f"βœ… Extension manager: {len(opts)} optimizers, {len(comps)} compression schemes")
    except Exception as exc:
        print(f"❌ Integration test failed: {exc}")
        return False

    return True

def test_bittransformerlm_integration():
    """Train a tiny BitTransformerLM for one step with the Muon extension.

    The test is skipped (and counted as passing) only when the
    ``bit_transformer`` package itself raises ImportError. ImportErrors
    from ``BTLM_Extensions`` are reported as failures, not skips, because
    that is the package under test. So are ImportErrors raised while
    building or running the model.

    Returns:
        True on success or skip, False on any failure (the error is printed).
    """
    print("\nTesting BitTransformerLM integration...")

    # Keep the skip path narrow: only the optional dependency's absence
    # justifies skipping.
    try:
        from bit_transformer import BitTransformerLM
    except ImportError:
        print("⚠️  BitTransformerLM not available, skipping integration test")
        return True
    except Exception as e:
        print(f"❌ BitTransformerLM integration failed: {e}")
        return False

    try:
        from BTLM_Extensions import configure_optimizer

        # Create a small BitTransformerLM model
        model = BitTransformerLM(
            d_model=64,
            nhead=4,
            num_layers=2,
            dim_feedforward=128,
            max_seq_len=32
        )

        # Test optimizer integration
        optimizer, scheduler = configure_optimizer("muon", model, lr=1e-3, total_steps=10)

        # Simple forward pass; the model returns (logits, telemetry).
        test_bits = torch.randint(0, 2, (2, 16))
        logits, _telemetry = model(test_bits)

        # Next-bit objective: logits at position t are scored against bit t+1.
        # NOTE(review): assumes logits is (batch, seq, 2); reshape(-1, 2)
        # requires a last dimension of 2.
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = test_bits[:, 1:].reshape(-1)
        loss = nn.functional.cross_entropy(pred, target)

        loss.backward()
        optimizer.step()
        if scheduler:
            scheduler.step()

        print(f"βœ… BitTransformerLM integration working (loss: {loss.item():.4f})")
        return True

    except Exception as e:
        print(f"❌ BitTransformerLM integration failed: {e}")
        return False

def main():
    """Run every test function in order and print a pass/fail summary.

    A test passes when it returns a truthy value. A test that raises is
    reported as crashed and counted as failed.

    Returns:
        Process exit status: 0 if all tests passed, 1 otherwise.
    """
    print("BTLM_Extensions Test Suite")
    print("=" * 40)

    suite = (
        test_imports,
        test_optimizers,
        test_rle_compression,
        test_integration,
        test_bittransformerlm_integration,
    )

    passed = 0
    for test_fn in suite:
        try:
            if test_fn():
                passed += 1
        except Exception as exc:
            print(f"❌ Test {test_fn.__name__} crashed: {exc}")

    total = len(suite)
    print("\n" + "=" * 40)
    print(f"Test Results: {passed}/{total} passed")

    if passed != total:
        print("⚠️  Some tests failed. Check the output above.")
        return 1

    print("πŸŽ‰ All tests passed! Extensions are working correctly.")
    return 0

if __name__ == "__main__":
    # Propagate the suite's result (0 = all passed) as the process exit status.
    sys.exit(main())