| |
| """ |
| Test the trained BitTransformerLM model and validate all features. |
| """ |
|
|
import logging
import sys

import numpy as np
import torch

from bit_transformer.compression import compress_bits_batch, model_output_decompress
from bit_transformer.model import BitTransformerLM
from enhanced_checkpoint_system import create_checkpoint_manager
|
|
# Logger named after this module (``__name__``).
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
# Fixed seed so the "random" inputs below are reproducible across runs.
_RANDOM_SEED = 0


def _latest_session_dir(sessions_dir):
    """Return the most recently modified session directory, or ``None``.

    Args:
        sessions_dir: Directory object with ``iterdir()`` (e.g. ``pathlib.Path``)
            whose sub-directories are training sessions.

    Returns:
        The newest sub-directory by ``st_mtime``. Returns ``None`` when the
        directory does not exist or contains no sub-directories.
    """
    try:
        # Skip stray files (OS/editor artefacts). Sessions appear to be
        # directories, since the storage report lists several checkpoints
        # per session.
        candidates = [entry for entry in sessions_dir.iterdir() if entry.is_dir()]
    except FileNotFoundError:
        return None
    return max(candidates, key=lambda entry: entry.stat().st_mtime, default=None)


def _split_model_output(output):
    """Split a forward-pass result into ``(logits, telemetry)``.

    The model may return either a bare logits tensor or a
    ``(logits, telemetry)`` tuple. ``telemetry`` is ``None`` in the former case.
    """
    if isinstance(output, tuple):
        logits, telemetry = output
        return logits, telemetry
    return output, None


def _predictions_from_logits(logits, seq_len):
    """Return the arg-max bit prediction for every position.

    Rank-3 logits are used as-is. Any other shape is first reshaped to
    ``(-1, seq_len, 2)``, i.e. one two-way score pair per bit position.
    """
    if logits.dim() != 3:
        logits = logits.reshape(-1, seq_len, 2)
    return torch.argmax(logits, dim=-1)


def _compute_safety_metrics(predictions, targets):
    """Compute the K/C/S safety metrics for a predicted bit sequence.

    Args:
        predictions: Tensor of predicted bits (0/1); flattened before use.
        targets: Tensor of reference bits (0/1); flattened before use.

    Returns:
        Dict with:

        * ``K_negentropy``: ``1 - H``, where ``H`` is the Shannon entropy in
          bits of the predicted-bit distribution. It is ``1.0`` when every
          prediction is the same bit.
        * ``C_complexity``: number of bit flips between adjacent predictions,
          divided by the sequence length. It is ``0.0`` for length <= 1.
        * ``S_symbiosis``: ``1 - |mean(targets) - mean(predictions)|``.
    """
    pred_bits = predictions.float().flatten()
    target_bits = targets.float().flatten()

    prob_1 = pred_bits.mean().item()
    prob_0 = 1.0 - prob_1
    if prob_0 > 0 and prob_1 > 0:
        entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
        negentropy = 1.0 - entropy
    else:
        negentropy = 1.0  # zero entropy: all predicted bits are identical

    n_bits = pred_bits.numel()
    if n_bits > 1:
        flips = (pred_bits[1:] != pred_bits[:-1]).sum().item()
        complexity = flips / n_bits
    else:
        complexity = 0.0

    symbiosis = 1.0 - abs(target_bits.mean() - pred_bits.mean()).item()

    return {
        "K_negentropy": negentropy,
        "C_complexity": complexity,
        "S_symbiosis": symbiosis,
    }


def _print_training_metrics(checkpoint_data):
    """Print the loss/K/C/S metrics stored with a checkpoint, if present."""
    try:
        metrics = checkpoint_data["model_data"]["metrics"]
        print(f"📊 Training metrics - Loss: {metrics['loss']:.4f}, "
              f"K: {metrics['K_negentropy']:.3f}, "
              f"C: {metrics['C_complexity']:.3f}, "
              f"S: {metrics['S_symbiosis']:.3f}")
    except (KeyError, TypeError, ValueError) as exc:
        # Missing or malformed metrics must not be reported as a failed load.
        print(f"⚠️ Training metrics unavailable in checkpoint: {exc!r}")


def test_trained_model():
    """Load the newest training session's checkpoint and smoke-test the model.

    The checks run in this order:

    1. Load the latest checkpoint.
    2. Run forward passes on a fixed and a seeded-random bit sequence.
    3. Verify a compress/decompress round trip.
    4. Print K/C/S safety metrics for a few patterns.
    5. Print a storage-usage report.

    NOTE: the ``test_`` prefix means pytest may collect this function. The
    boolean return value is only meaningful to the ``__main__`` runner.

    Returns:
        ``True`` when every check passed. ``False`` when no session exists,
        the checkpoint cannot be loaded, or the compression round trip fails.
    """
    print("🧪 Testing trained BitTransformerLM model...")

    manager = create_checkpoint_manager()

    latest_session = _latest_session_dir(manager.sessions_dir)
    if latest_session is None:
        print("❌ No training sessions found")
        return False
    session_id = latest_session.name
    print(f"📂 Loading from session: {session_id}")

    # NOTE(review): these hyper-parameters are hard-coded. They presumably
    # must match the configuration used for training; confirm against the
    # trainer.
    model = BitTransformerLM(
        d_model=256,
        nhead=8,
        num_layers=4,
        dim_feedforward=512,
        max_seq_len=128,
        use_checkpoint=True,
        chunk_size=None,
    )

    try:
        checkpoint_data = manager.load_checkpoint(session_id, model=model)
    except Exception as exc:  # script boundary: report any load failure and stop
        logger.exception("Failed to load checkpoint for session %s", session_id)
        print(f"❌ Failed to load checkpoint: {exc}")
        return False

    print(f"✅ Model loaded from: {checkpoint_data.get('checkpoint_path', '<unknown>')}")
    _print_training_metrics(checkpoint_data)

    model.eval()
    with torch.no_grad():
        print("\n🔬 Testing model inference...")

        # Fixed alternating pattern.
        test_input1 = torch.tensor([[0, 1, 0, 1, 0, 1, 0, 1]], dtype=torch.long)
        logits1, telemetry1 = _split_model_output(model(test_input1))
        print(f"✅ Forward pass successful, output shape: {logits1.shape}")
        if telemetry1 is not None:
            print(f"📡 Telemetry keys: {list(telemetry1.keys())}")
        predictions1 = _predictions_from_logits(logits1, test_input1.shape[-1])
        print(f"📥 Input: {test_input1.squeeze().tolist()}")
        print(f"📤 Output: {predictions1.squeeze().tolist()}")

        # Random input. A dedicated seeded generator keeps runs reproducible
        # without touching torch's global RNG state.
        rng = torch.Generator().manual_seed(_RANDOM_SEED)
        test_input2 = torch.randint(0, 2, (1, 16), dtype=torch.long, generator=rng)
        logits2, _ = _split_model_output(model(test_input2))
        predictions2 = _predictions_from_logits(logits2, test_input2.shape[-1])
        print(f"\n📥 Random input: {test_input2.squeeze().tolist()}")
        print(f"📤 Model output: {predictions2.squeeze().tolist()}")

        print("\n🗜️ Testing compression features...")
        long_sequence = torch.randint(0, 2, (1, 64), dtype=torch.long, generator=rng)
        compressed = compress_bits_batch(long_sequence)
        original_len = long_sequence.shape[-1]
        compressed_len = len(compressed[0])
        print(f"Original length: {original_len}")
        print(f"Compressed length: {compressed_len}")
        print(f"Compression ratio: {compressed_len / original_len:.2f}")

        # NOTE(review): assumes model_output_decompress returns a tensor shaped
        # like the input batch. torch.equal is False on any size/value mismatch.
        decompressed = model_output_decompress(compressed)
        compression_success = torch.equal(long_sequence, decompressed)
        status = "✅" if compression_success else "❌"
        print(f"{status} Compression/decompression successful: {compression_success}")

        print("\n🛡️ Testing safety metrics...")
        test_patterns = [
            [0, 1, 0, 1, 0, 1, 0, 1],
            [1, 1, 1, 1, 0, 0, 0, 0],
            [0, 1, 1, 0, 1, 0, 1, 1],
        ]
        for index, pattern in enumerate(test_patterns, start=1):
            test_seq = torch.tensor([pattern], dtype=torch.long)
            model_logits, _ = _split_model_output(model(test_seq))
            model_preds = _predictions_from_logits(model_logits, len(pattern))
            pattern_metrics = _compute_safety_metrics(model_preds, test_seq)
            print(f"Pattern {index}: K={pattern_metrics['K_negentropy']:.3f}, "
                  f"C={pattern_metrics['C_complexity']:.3f}, "
                  f"S={pattern_metrics['S_symbiosis']:.3f}")

    print("\n💾 Storage usage report:")
    usage = manager.get_storage_usage()
    print(f"Total storage used: {usage['total_gb']:.3f} GB")
    print(f"Training sessions: {usage['num_sessions']}")
    print(f"Best models saved: {usage['num_best_models']}")
    for session in usage['sessions'][:3]:
        print(f" - {session['session_id']}: {session['size_gb']:.3f} GB "
              f"({session['num_checkpoints']} checkpoints)")

    # A failed round trip must not be reported as overall success.
    if not compression_success:
        print("\n❌ Model testing finished with failures (compression round trip mismatch)")
        return False

    print("\n🎉 Model testing completed successfully!")
    return True
|
|
if __name__ == "__main__":
    success = test_trained_model()
    if success:
        print("✅ ALL TESTS PASSED!")
    else:
        print("❌ Some tests failed")
    # Propagate the outcome as the exit status so shell/CI callers can detect failures.
    sys.exit(0 if success else 1)