| | |
| | """ |
| | Test the trained BitTransformerLM model and validate all features. |
| | """ |
| |
|
| | import torch |
| | import numpy as np |
| | import logging |
| | from enhanced_checkpoint_system import create_checkpoint_manager |
| | from bit_transformer.model import BitTransformerLM |
| | from bit_transformer.compression import compress_bits_batch, model_output_decompress |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
def _unpack_logits(output):
    """Split a model call result into ``(logits, telemetry)``.

    The model is called in this script as if it may return either bare logits
    or a ``(logits, telemetry)`` pair; ``telemetry`` is ``None`` when only
    logits were returned.
    """
    if isinstance(output, tuple):
        logits, telemetry = output
        return logits, telemetry
    return output, None


def _logits_to_bits(logits, batch_size=1):
    """Turn logits into hard bit predictions (argmax over the last axis).

    A 3-D tensor is used as-is.  Any other rank is regrouped into
    ``(batch_size, -1, 2)`` first, so the number of positions is inferred
    instead of being hard-coded per call site.
    NOTE(review): assumes the trailing axis holds the two bit classes, as the
    original ``reshape(1, N, 2)`` calls did -- confirm against the model.
    """
    if logits.dim() != 3:
        logits = logits.reshape(batch_size, -1, 2)
    return torch.argmax(logits, dim=-1)


def _compute_safety_metrics(predictions, targets):
    """Compute the K/C/S proxies used by this script from hard predictions.

    Args:
        predictions: Tensor of predicted bits (0/1).
        targets: Tensor of reference bits (0/1).

    Returns:
        Dict with keys ``K_negentropy`` (1 - binary entropy of the predicted
        bit frequency), ``C_complexity`` (fraction of adjacent positions whose
        bit differs, normalised by the sequence length) and ``S_symbiosis``
        (1 - absolute difference between target and prediction means).
    """
    pred_bits = predictions.float().flatten()
    target_bits = targets.float().flatten()

    # K: negentropy of the empirical bit distribution.
    prob_1 = pred_bits.mean().item()
    prob_0 = 1 - prob_1
    if prob_0 > 0 and prob_1 > 0:
        entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
        negentropy = 1.0 - entropy
    else:
        # A constant sequence has zero entropy; also avoids log2(0).
        negentropy = 1.0

    # C: how often consecutive predicted bits flip.
    changes = (pred_bits[1:] != pred_bits[:-1]).sum().item()
    complexity = changes / len(pred_bits) if len(pred_bits) > 1 else 0.0

    # S: agreement between the mean of the targets and of the predictions.
    symbiosis = 1.0 - abs(target_bits.mean() - pred_bits.mean()).item()

    return {
        'K_negentropy': negentropy,
        'C_complexity': complexity,
        'S_symbiosis': symbiosis,
    }


def test_trained_model():
    """Test the most recent trained model.

    Loads the checkpoint of the most recently modified training session into
    a freshly built ``BitTransformerLM``, runs a few forward passes, checks a
    compression/decompression round trip, prints safety-metric proxies for
    fixed bit patterns and finally prints the storage usage report.

    Returns:
        bool: ``True`` when every check passed; ``False`` when no session was
        found, the checkpoint could not be loaded, or the compression round
        trip did not reproduce its input.
    """
    print("🧪 Testing trained BitTransformerLM model...")

    manager = create_checkpoint_manager()

    # Pick the most recently modified entry of the sessions directory.
    sessions = list(manager.sessions_dir.iterdir())
    if not sessions:
        print("❌ No training sessions found")
        return False

    latest_session = max(sessions, key=lambda entry: entry.stat().st_mtime)
    session_id = latest_session.name

    print(f"📁 Loading from session: {session_id}")

    # NOTE(review): these hyper-parameters presumably have to match the ones
    # used for training, otherwise loading the weights will fail -- confirm.
    model = BitTransformerLM(
        d_model=256,
        nhead=8,
        num_layers=4,
        dim_feedforward=512,
        max_seq_len=128,
        use_checkpoint=True,
        chunk_size=None,
    )

    try:
        checkpoint_data = manager.load_checkpoint(session_id, model=model)
        print(f"✅ Model loaded from: {checkpoint_data['checkpoint_path']}")

        train_metrics = checkpoint_data['model_data']['metrics']
        print(f"📊 Training metrics - Loss: {train_metrics['loss']:.4f}, "
              f"K: {train_metrics['K_negentropy']:.3f}, "
              f"C: {train_metrics['C_complexity']:.3f}, "
              f"S: {train_metrics['S_symbiosis']:.3f}")
    except Exception as e:
        # Top-level boundary of the script: report and give up on the run.
        print(f"❌ Failed to load checkpoint: {e}")
        return False

    all_passed = True

    model.eval()
    with torch.no_grad():
        print("\n🔬 Testing model inference...")

        # Fixed alternating pattern.
        test_input1 = torch.tensor([[0, 1, 0, 1, 0, 1, 0, 1]], dtype=torch.long)
        logits1, telemetry1 = _unpack_logits(model(test_input1))
        print(f"✅ Forward pass successful, output shape: {logits1.shape}")
        if telemetry1 is not None:
            print(f"📡 Telemetry keys: {list(telemetry1.keys())}")

        predictions1 = _logits_to_bits(logits1)
        print(f"📥 Input:  {test_input1.squeeze().tolist()}")
        print(f"📤 Output: {predictions1.squeeze().tolist()}")

        # Random input of a different length.
        test_input2 = torch.randint(0, 2, (1, 16), dtype=torch.long)
        logits2, _ = _unpack_logits(model(test_input2))
        predictions2 = _logits_to_bits(logits2)
        print(f"\n📥 Random input:  {test_input2.squeeze().tolist()}")
        print(f"📤 Model output: {predictions2.squeeze().tolist()}")

        print("\n🗜️ Testing compression features...")

        long_sequence = torch.randint(0, 2, (1, 64), dtype=torch.long)

        compressed = compress_bits_batch(long_sequence)
        original_len = long_sequence.shape[-1]
        compressed_len = len(compressed[0])
        print(f"Original length: {original_len}")
        print(f"Compressed length: {compressed_len}")
        print(f"Compression ratio: {compressed_len / original_len:.2f}")

        # The round trip must reproduce the input exactly; a mismatch is a
        # failed check, so it is marked as such and counted in the verdict.
        decompressed = model_output_decompress(compressed)
        compression_success = torch.equal(long_sequence, decompressed)
        status_mark = "✅" if compression_success else "❌"
        print(f"{status_mark} Compression/decompression successful: {compression_success}")
        all_passed = all_passed and compression_success

        print("\n🛡️ Testing safety metrics...")

        test_patterns = [
            [0, 1, 0, 1, 0, 1, 0, 1],
            [1, 1, 1, 1, 0, 0, 0, 0],
            [0, 1, 1, 0, 1, 0, 1, 1],
        ]

        for i, pattern in enumerate(test_patterns):
            test_seq = torch.tensor([pattern], dtype=torch.long)
            model_logits, _ = _unpack_logits(model(test_seq))
            model_preds = _logits_to_bits(model_logits)
            pattern_metrics = _compute_safety_metrics(model_preds, test_seq)

            print(f"Pattern {i+1}: K={pattern_metrics['K_negentropy']:.3f}, "
                  f"C={pattern_metrics['C_complexity']:.3f}, "
                  f"S={pattern_metrics['S_symbiosis']:.3f}")

    print("\n💾 Storage usage report:")
    usage = manager.get_storage_usage()
    print(f"Total storage used: {usage['total_gb']:.3f} GB")
    print(f"Training sessions: {usage['num_sessions']}")
    print(f"Best models saved: {usage['num_best_models']}")

    # Only the first three sessions of the report are listed.
    for session in usage['sessions'][:3]:
        print(f"  - {session['session_id']}: {session['size_gb']:.3f} GB "
              f"({session['num_checkpoints']} checkpoints)")

    if all_passed:
        print("\n🎉 Model testing completed successfully!")
    else:
        print("\n❌ Model testing completed with failures")
    return all_passed
| |
|
if __name__ == "__main__":
    # Run the checks when executed as a script and report the overall result.
    success = test_trained_model()
    if success:
        print("✅ ALL TESTS PASSED!")
    else:
        print("❌ Some tests failed")
    # Propagate the verdict through the exit status so shells and CI jobs
    # can detect a failed run instead of always seeing exit code 0.
    raise SystemExit(0 if success else 1)