WCNegentropy committed on
Commit
b2e8740
·
verified ·
1 Parent(s): 8601a92

Remove test_trained_model.py - cleanup for OS launch

Browse files
Files changed (1) hide show
  1. test_trained_model.py +0 -188
test_trained_model.py DELETED
@@ -1,188 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Test the trained BitTransformerLM model and validate all features.
4
- """
5
-
6
- import torch
7
- import numpy as np
8
- import logging
9
- from enhanced_checkpoint_system import create_checkpoint_manager
10
- from bit_transformer.model import BitTransformerLM
11
- from bit_transformer.compression import compress_bits_batch, model_output_decompress
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
- def test_trained_model():
16
- """Test the most recent trained model."""
17
-
18
- print("🧪 Testing trained BitTransformerLM model...")
19
-
20
- # Load checkpoint manager
21
- manager = create_checkpoint_manager()
22
-
23
- # Find the most recent session
24
- sessions = list(manager.sessions_dir.iterdir())
25
- if not sessions:
26
- print("❌ No training sessions found")
27
- return
28
-
29
- latest_session = max(sessions, key=lambda x: x.stat().st_mtime)
30
- session_id = latest_session.name
31
-
32
- print(f"📁 Loading from session: {session_id}")
33
-
34
- # Initialize model with same config
35
- model = BitTransformerLM(
36
- d_model=256,
37
- nhead=8,
38
- num_layers=4,
39
- dim_feedforward=512,
40
- max_seq_len=128,
41
- use_checkpoint=True,
42
- chunk_size=None
43
- )
44
-
45
- # Load checkpoint
46
- try:
47
- checkpoint_data = manager.load_checkpoint(session_id, model=model)
48
- print(f"✅ Model loaded from: {checkpoint_data['checkpoint_path']}")
49
-
50
- metrics = checkpoint_data['model_data']['metrics']
51
- print(f"📊 Training metrics - Loss: {metrics['loss']:.4f}, "
52
- f"K: {metrics['K_negentropy']:.3f}, "
53
- f"C: {metrics['C_complexity']:.3f}, "
54
- f"S: {metrics['S_symbiosis']:.3f}")
55
-
56
- except Exception as e:
57
- print(f"❌ Failed to load checkpoint: {e}")
58
- return
59
-
60
- # Test inference
61
- model.eval()
62
- with torch.no_grad():
63
- print("\n🔬 Testing model inference...")
64
-
65
- # Test 1: Simple alternating pattern
66
- test_input1 = torch.tensor([[0, 1, 0, 1, 0, 1, 0, 1]], dtype=torch.long)
67
- output1 = model(test_input1)
68
-
69
- if isinstance(output1, tuple):
70
- logits1, telemetry1 = output1
71
- print(f"✅ Forward pass successful, output shape: {logits1.shape}")
72
- print(f"📡 Telemetry keys: {list(telemetry1.keys())}")
73
- else:
74
- logits1 = output1
75
- print(f"✅ Forward pass successful, output shape: {logits1.shape}")
76
-
77
- # Get predictions
78
- if logits1.dim() == 3:
79
- predictions1 = torch.argmax(logits1, dim=-1)
80
- else:
81
- predictions1 = torch.argmax(logits1.reshape(1, 8, 2), dim=-1)
82
-
83
- print(f"📥 Input: {test_input1.squeeze().tolist()}")
84
- print(f"📤 Output: {predictions1.squeeze().tolist()}")
85
-
86
- # Test 2: Random pattern
87
- test_input2 = torch.randint(0, 2, (1, 16), dtype=torch.long)
88
- output2 = model(test_input2)
89
-
90
- if isinstance(output2, tuple):
91
- logits2, telemetry2 = output2
92
- else:
93
- logits2 = output2
94
-
95
- predictions2 = torch.argmax(logits2.reshape(1, 16, 2), dim=-1)
96
- print(f"\n📥 Random input: {test_input2.squeeze().tolist()}")
97
- print(f"📤 Model output: {predictions2.squeeze().tolist()}")
98
-
99
- # Test 3: Compression/Decompression
100
- print("\n🗜️ Testing compression features...")
101
-
102
- # Create a longer sequence for compression testing
103
- long_sequence = torch.randint(0, 2, (1, 64), dtype=torch.long)
104
-
105
- # Test compression
106
- compressed = compress_bits_batch(long_sequence)
107
- print(f"Original length: {long_sequence.shape[-1]}")
108
- print(f"Compressed length: {len(compressed[0])}")
109
- print(f"Compression ratio: {len(compressed[0]) / long_sequence.shape[-1]:.2f}")
110
-
111
- # Test decompression
112
- decompressed = model_output_decompress(compressed)
113
- compression_success = torch.equal(long_sequence, decompressed)
114
- print(f"✅ Compression/decompression successful: {compression_success}")
115
-
116
- # Test 4: Safety metrics computation
117
- print("\n🛡️ Testing safety metrics...")
118
-
119
- def compute_safety_metrics(predictions, targets):
120
- pred_bits = predictions.float().flatten()
121
- target_bits = targets.float().flatten()
122
-
123
- # K metric (Negentropy)
124
- prob_1 = pred_bits.mean().item()
125
- prob_0 = 1 - prob_1
126
- if prob_0 > 0 and prob_1 > 0:
127
- entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
128
- negentropy = 1.0 - entropy
129
- else:
130
- negentropy = 1.0
131
-
132
- # C metric (Complexity)
133
- changes = (pred_bits[1:] != pred_bits[:-1]).sum().item()
134
- complexity = changes / len(pred_bits) if len(pred_bits) > 1 else 0.0
135
-
136
- # S metric (Symbiosis)
137
- target_mean = target_bits.mean()
138
- pred_mean = pred_bits.mean()
139
- symbiosis = 1.0 - abs(target_mean - pred_mean).item()
140
-
141
- return {
142
- 'K_negentropy': negentropy,
143
- 'C_complexity': complexity,
144
- 'S_symbiosis': symbiosis
145
- }
146
-
147
- # Test on several patterns
148
- test_patterns = [
149
- [0, 1, 0, 1, 0, 1, 0, 1], # Alternating
150
- [1, 1, 1, 1, 0, 0, 0, 0], # Block pattern
151
- [0, 1, 1, 0, 1, 0, 1, 1], # Mixed
152
- ]
153
-
154
- for i, pattern in enumerate(test_patterns):
155
- test_seq = torch.tensor([pattern], dtype=torch.long)
156
- model_out = model(test_seq)
157
- if isinstance(model_out, tuple):
158
- model_logits, _ = model_out
159
- else:
160
- model_logits = model_out
161
-
162
- model_preds = torch.argmax(model_logits.reshape(1, len(pattern), 2), dim=-1)
163
- metrics = compute_safety_metrics(model_preds, test_seq)
164
-
165
- print(f"Pattern {i+1}: K={metrics['K_negentropy']:.3f}, "
166
- f"C={metrics['C_complexity']:.3f}, "
167
- f"S={metrics['S_symbiosis']:.3f}")
168
-
169
- # Storage usage report
170
- print(f"\n💾 Storage usage report:")
171
- usage = manager.get_storage_usage()
172
- print(f"Total storage used: {usage['total_gb']:.3f} GB")
173
- print(f"Training sessions: {usage['num_sessions']}")
174
- print(f"Best models saved: {usage['num_best_models']}")
175
-
176
- for session in usage['sessions'][:3]: # Top 3 sessions by size
177
- print(f" - {session['session_id']}: {session['size_gb']:.3f} GB "
178
- f"({session['num_checkpoints']} checkpoints)")
179
-
180
- print("\n🎉 Model testing completed successfully!")
181
- return True
182
-
183
- if __name__ == "__main__":
184
- success = test_trained_model()
185
- if success:
186
- print("✅ ALL TESTS PASSED!")
187
- else:
188
- print("❌ Some tests failed")