| | |
| | """ |
| | Full end-to-end BitTransformerLM training run with all optimizations! |
| | Small scale test to validate our enhanced system. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from torch.utils.data import Dataset, DataLoader |
| | import numpy as np |
| | import logging |
| | from pathlib import Path |
| | import time |
| | from typing import List, Dict, Any |
| |
|
| | |
| | from bit_transformer.model import BitTransformerLM |
| | from bit_transformer.compression import compress_bits_batch, model_output_decompress |
| | from bit_transformer.error_handling import safe_model_forward, setup_error_logging |
| | from bit_transformer.types import BitSequence, TelemetryDict |
| | from enhanced_checkpoint_system import create_checkpoint_manager |
| |
|
| | |
| | logger = setup_error_logging("INFO") |
| |
|
class SimpleBitDataset(Dataset):
    """Synthetic bit-sequence dataset mixing four pattern families.

    One quarter each of: constant streams (all 0s / all 1s), independent
    uniform-random bits, random run-length-encoded stretches, and an XOR
    recurrence (x[n] = x[n-1] ^ x[n-2]) seeded with [0, 1]. Items are
    (input, target) pairs for next-bit prediction.
    """

    def __init__(self, num_samples: int = 1000, seq_length: int = 128):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.data = self._generate_bit_sequences()

    def _generate_bit_sequences(self) -> List[torch.Tensor]:
        """Build the full list of ``num_samples`` bit tensors."""
        quarter = self.num_samples // 4
        sequences: List[torch.Tensor] = []

        # Family 1: constant streams — even indices all zeros, odd all ones.
        sequences.extend(
            torch.full((self.seq_length,), idx % 2, dtype=torch.long)
            for idx in range(quarter)
        )

        # Family 2: independent uniform-random bits.
        sequences.extend(
            torch.randint(0, 2, (self.seq_length,), dtype=torch.long)
            for _ in range(quarter)
        )

        # Family 3: random-length runs (1-19 bits) of a single random value.
        for _ in range(quarter):
            bits: List[int] = []
            while len(bits) < self.seq_length:
                run = min(np.random.randint(1, 20), self.seq_length - len(bits))
                bits.extend([np.random.randint(0, 2)] * run)
            sequences.append(torch.tensor(bits[:self.seq_length], dtype=torch.long))

        # Family 4 (fills the remainder, covering num_samples % 4 != 0):
        # XOR recurrence over the last two bits, giving a period-3 pattern.
        for _ in range(self.num_samples - len(sequences)):
            bits = [0, 1]
            while len(bits) < self.seq_length:
                bits.append(bits[-1] ^ bits[-2])
            sequences.append(torch.tensor(bits[:self.seq_length], dtype=torch.long))

        return sequences

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Next-bit prediction: input drops the final bit, target is the
        # same sequence shifted left by one position.
        seq = self.data[idx]
        return seq[:-1], seq[1:]
| |
|
| |
|
def compute_safety_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
    """Compute the K/C/S safety metrics for a batch of predictions.

    Args:
        predictions: Per-position probabilities of bit value 1 (any shape;
            thresholded at 0.5 into hard bit decisions).
        targets: Ground-truth bit values (any shape).

    Returns:
        Dict with 'K_negentropy', 'C_complexity', 'S_symbiosis', each in [0, 1].
    """
    # Threshold probabilities into hard 0/1 decisions for all three metrics.
    pred_bits = (predictions > 0.5).float().flatten()

    # K (negentropy): 1 - Shannon entropy of the predicted bit distribution.
    # 1.0 for a constant stream, 0.0 for a perfect 50/50 mix.
    if len(pred_bits) > 0:
        prob_1 = pred_bits.mean().item()
        prob_0 = 1 - prob_1
        if prob_0 > 0 and prob_1 > 0:
            entropy = -prob_0 * np.log2(prob_0) - prob_1 * np.log2(prob_1)
            negentropy = 1.0 - entropy
        else:
            # Degenerate distribution: prob_1 is exactly 0 or 1, so entropy
            # is zero and negentropy maximal. (The original conditional here
            # could only ever evaluate to 1.0 in this branch.)
            negentropy = 1.0
    else:
        negentropy = 0.0

    # C (complexity): fraction of adjacent positions whose bit flips.
    # FIX: normalize by the number of adjacent pairs (len - 1), not len —
    # previously a perfectly alternating stream could never reach the
    # intended maximum of 1.0 implied by the min(..., 1.0) cap.
    if len(pred_bits) > 1:
        changes = (pred_bits[1:] != pred_bits[:-1]).sum().item()
        complexity = min(changes / (len(pred_bits) - 1), 1.0)
    else:
        complexity = 0.0

    # S (symbiosis): agreement between predicted and target mean bit rates.
    # FIX: also require non-empty predictions — mean() of an empty tensor is
    # NaN, which previously poisoned the metric when targets were non-empty.
    target_bits = targets.float().flatten()
    if len(target_bits) > 0 and len(pred_bits) > 0:
        symbiosis = 1.0 - abs(target_bits.mean() - pred_bits.mean()).item()
    else:
        symbiosis = 1.0

    return {
        'K_negentropy': negentropy,
        'C_complexity': complexity,
        'S_symbiosis': symbiosis
    }
| |
|
| |
|
def train_bittransformer():
    """Run a full end-to-end BitTransformerLM training session.

    Builds a synthetic bit dataset, trains a small BitTransformerLM with
    periodic checkpointing and K/C/S safety-metric logging, then smoke-tests
    inference and compression round-tripping.

    Returns:
        Tuple of (session_id, trained model, checkpoint manager).
    """
    # NOTE(review): the emoji in log strings throughout this function appear
    # mojibake'd (e.g. "π") — preserved as-is except where they broke syntax.
    logger.info("π Starting BitTransformerLM end-to-end training run!")

    # Small-scale model configuration for a validation run.
    model_config = {
        'd_model': 256,
        'nhead': 8,
        'num_layers': 4,
        'dim_feedforward': 512,
        'max_seq_len': 128,
        'use_checkpoint': True,  # gradient checkpointing to reduce memory
        'chunk_size': None,
    }

    training_config = {
        'batch_size': 16,
        'learning_rate': 1e-3,
        'num_epochs': 10,
        'save_every_n_epochs': 2,
        'log_every_n_steps': 10
    }

    # Register a resumable training session with the checkpoint manager.
    checkpoint_manager = create_checkpoint_manager()
    session_id = checkpoint_manager.create_training_session(
        session_name="end_to_end_test",
        model_config=model_config,
        training_config=training_config
    )

    logger.info(f"π Created training session: {session_id}")

    # Synthetic bit-sequence dataset and loader.
    logger.info("π Creating training dataset...")
    dataset = SimpleBitDataset(num_samples=800, seq_length=model_config['max_seq_len'])
    dataloader = DataLoader(dataset, batch_size=training_config['batch_size'], shuffle=True)

    logger.info("π§ Initializing BitTransformerLM model...")
    model = BitTransformerLM(
        d_model=model_config['d_model'],
        nhead=model_config['nhead'],
        num_layers=model_config['num_layers'],
        dim_feedforward=model_config['dim_feedforward'],
        max_seq_len=model_config['max_seq_len'],
        use_checkpoint=model_config['use_checkpoint'],
        chunk_size=model_config['chunk_size']
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"π’ Model parameters: {total_params:,} total, {trainable_params:,} trainable")

    optimizer = optim.AdamW(model.parameters(), lr=training_config['learning_rate'])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_config['num_epochs'])
    criterion = nn.CrossEntropyLoss()

    logger.info("πββοΈ Starting training loop...")

    for epoch in range(training_config['num_epochs']):
        model.train()
        epoch_loss = 0.0
        epoch_metrics = {'K_negentropy': 0.0, 'C_complexity': 0.0, 'S_symbiosis': 0.0}
        num_batches = 0

        start_time = time.time()

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()

            try:
                # The wrapped forward may return (logits, telemetry) or bare
                # logits; normalize both cases.
                output = safe_model_forward(model, inputs)
                if isinstance(output, tuple):
                    logits, telemetry = output
                else:
                    logits = output
                    telemetry = {}

                # Flatten to (N, 2) logits vs (N,) class indices for
                # cross-entropy over the two bit classes.
                if logits.dim() == 2:
                    logits_flat = logits
                    targets_flat = targets.reshape(-1)
                else:
                    logits_flat = logits.reshape(-1, 2)
                    targets_flat = targets.reshape(-1)

                loss = criterion(logits_flat, targets_flat)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

                # Per-position P(bit == 1) for the K/C/S safety metrics.
                with torch.no_grad():
                    if logits.dim() == 2:
                        # Already flattened — restore (batch, seq, 2) first.
                        batch_size = inputs.shape[0]
                        seq_len = inputs.shape[1]
                        logits_reshaped = logits.reshape(batch_size, seq_len, 2)
                        predictions = torch.softmax(logits_reshaped, dim=-1)[:, :, 1]
                    else:
                        predictions = torch.softmax(logits, dim=-1)[:, :, 1]

                    safety_metrics = compute_safety_metrics(predictions, targets)

                epoch_loss += loss.item()
                for key, value in safety_metrics.items():
                    epoch_metrics[key] += value
                num_batches += 1

                if batch_idx % training_config['log_every_n_steps'] == 0:
                    logger.info(f"Epoch {epoch+1}/{training_config['num_epochs']}, "
                                f"Batch {batch_idx}/{len(dataloader)}, "
                                f"Loss: {loss.item():.4f}, "
                                f"K: {safety_metrics['K_negentropy']:.3f}, "
                                f"C: {safety_metrics['C_complexity']:.3f}, "
                                f"S: {safety_metrics['S_symbiosis']:.3f}")

            except Exception as e:
                # Best-effort training: log the failing batch and move on.
                logger.error(f"Error in batch {batch_idx}: {e}")
                continue

        scheduler.step()
        epoch_time = time.time() - start_time

        if num_batches > 0:
            avg_loss = epoch_loss / num_batches
            avg_metrics = {k: v / num_batches for k, v in epoch_metrics.items()}

            # FIX: this log call was split mid-string (a garbled emoji broke
            # the line) and did not parse; rejoined into one f-string.
            logger.info(f"β Epoch {epoch+1} completed in {epoch_time:.2f}s")
            logger.info(f"π Avg Loss: {avg_loss:.4f}")
            logger.info(f"π Safety Metrics - K: {avg_metrics['K_negentropy']:.3f}, "
                        f"C: {avg_metrics['C_complexity']:.3f}, "
                        f"S: {avg_metrics['S_symbiosis']:.3f}")

            # Periodic full checkpoint (model + optimizer + scheduler state).
            if (epoch + 1) % training_config['save_every_n_epochs'] == 0:
                checkpoint_success = checkpoint_manager.save_checkpoint(
                    model=model,
                    session_id=session_id,
                    epoch=epoch + 1,
                    metrics={
                        'loss': avg_loss,
                        'learning_rate': scheduler.get_last_lr()[0],
                        **avg_metrics
                    },
                    optimizer_state=optimizer.state_dict(),
                    scheduler_state=scheduler.state_dict()
                )

                if checkpoint_success:
                    logger.info(f"πΎ Checkpoint saved for epoch {epoch+1}")

            # Track the best model by lowest average epoch loss.
            checkpoint_manager.save_best_model(
                session_id=session_id,
                model=model,
                metric_name='loss',
                metric_value=avg_loss,
                is_better_func=lambda x, y: x < y
            )

    logger.info("π Training completed successfully!")

    # Post-training smoke test: inference plus compression round-trip.
    logger.info("π§ͺ Testing model inference and compression...")

    model.eval()
    with torch.no_grad():
        test_input = torch.randint(0, 2, (1, 64), dtype=torch.long)
        logger.info(f"π₯ Input sequence: {test_input.squeeze().tolist()}")

        # FIX: match the training loop — the model may return a
        # (logits, telemetry) tuple; previously the raw tuple was fed
        # straight into torch.softmax, which would crash.
        output = model(test_input)
        output_logits = output[0] if isinstance(output, tuple) else output
        output_probs = torch.softmax(output_logits, dim=-1)
        predicted_bits = torch.argmax(output_probs, dim=-1)

        logger.info(f"π€ Predicted sequence: {predicted_bits.squeeze().tolist()}")

        compressed = compress_bits_batch(predicted_bits)
        logger.info(f"ποΈ Compressed length: {len(compressed[0])} (original: {predicted_bits.shape[-1]})")

        # Verify the compression is lossless for the predicted stream.
        decompressed = model_output_decompress(compressed)
        compression_success = torch.equal(predicted_bits, decompressed)
        # FIX: rejoined a second log call that was split mid-string.
        logger.info(f"β Compression/decompression successful: {compression_success}")

    storage_usage = checkpoint_manager.get_storage_usage()
    logger.info(f"πΎ Final storage usage: {storage_usage['total_gb']:.3f} GB")
    logger.info(f"π Training sessions: {storage_usage['num_sessions']}")

    return session_id, model, checkpoint_manager
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: execute the end-to-end run and make failures loud —
    # log the error, print a traceback, then re-raise for a non-zero exit.
    try:
        run_id, fitted_model, ckpt_manager = train_bittransformer()
        print(f"\nπ SUCCESS! Training session completed: {run_id}")
        print(f"π Use checkpoint_manager.load_checkpoint('{run_id}') to resume")
    except Exception as err:
        logger.error(f"β Training failed: {err}")
        import traceback
        traceback.print_exc()
        raise