"""
CPU-Optimized Edge Deployment BitTransformerLM Training
Optimized for consumer devices and edge applications.
"""
| |
|
| | import os |
| | import time |
| | import torch |
| | import torch.nn.functional as F |
| | from datasets import load_dataset |
| |
|
| | from bit_transformer import ( |
| | BitTransformerLM, |
| | text_to_bits, |
| | bits_to_text, |
| | train_loop, |
| | configure_optimizer, |
| | save_model, |
| | load_model, |
| | set_dropout, |
| | hil_safe_inference, |
| | quantize_dynamic, |
| | ) |
| | from bit_transformer.torch_utils import cpu_autocast |
| | from bit_transformer.training import train_loop |
| |
|
| |
|
def create_optimal_cpu_model():
    """Create BitTransformerLM optimized for CPU edge deployment."""
    print("π§ Creating CPU-optimized BitTransformerLM...")

    # Small, CPU-friendly configuration: no reversible layers or gradient
    # checkpointing (both favor memory over speed), autocast enabled, and
    # modest chunked attention.
    config = dict(
        d_model=64,
        nhead=4,
        num_layers=3,
        dim_feedforward=128,
        max_seq_len=256,
        reversible=False,
        use_checkpoint=False,
        use_autocast=True,
        use_act=False,
        chunk_size=32,
        full_attn_logging=False,
        lambda_K=1.0,
        lambda_C=1.0,
        lambda_S=1.0,
    )
    model = BitTransformerLM(**config)

    # Parameter accounting for the summary printout.
    all_params = list(model.parameters())
    total_params = sum(p.numel() for p in all_params)
    trainable_params = sum(p.numel() for p in all_params if p.requires_grad)

    print(f" π Model Configuration:")
    print(f"   d_model: {64}")
    print(f"   num_layers: {3}")
    print(f"   nhead: {4}")
    print(f"   dim_feedforward: {128}")
    print(f"   max_seq_len: {256}")
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,}")
    print(f"   Estimated size: {total_params * 4 / 1024 / 1024:.1f}MB (FP32)")
    print(f"   With autocast: ~{total_params * 2 / 1024 / 1024:.1f}MB (BF16)")

    return model
| |
|
| |
|
def load_training_dataset(dataset_size=512, max_len=128):
    """Load and prepare a training dataset optimized for edge training.

    Tries, in order: the BitTransformerLM HF dataset, WikiText-2, and a
    synthetic fallback. Each text is converted to a fixed-length bit
    sequence (zero-padded to ``max_len``) and split 80/20 by position
    into train/validation sets.

    Args:
        dataset_size: Maximum number of text samples to load.
        max_len: Bit-sequence length; shorter sequences are zero-padded.

    Returns:
        Tuple ``(train_tensor, valid_tensor, texts)`` where the tensors
        are LongTensors of shape (N, max_len) and ``texts`` are the raw
        strings matching the training split.
    """
    print("π Loading training dataset...")

    try:
        print(" Attempting to load BitTransformerLM dataset...")
        dataset = load_dataset(
            "WCNegentropy/BitTransformerLM",
            split="train[:{}]".format(dataset_size),
        )
        if dataset and len(dataset) > 0:
            train_texts = [item['text'] for item in dataset if item.get('text')]
            if len(train_texts) > 0:
                # FIX: this message was an unterminated f-string split
                # across two lines in the original (a SyntaxError).
                print(f" β Loaded {len(train_texts)} samples from BitTransformerLM dataset")
            else:
                raise Exception("No text samples found in dataset")
        else:
            raise Exception("Dataset empty or not accessible")

    except Exception as e:
        print(f" β οΈ BitTransformerLM dataset not available: {e}")
        print(" π Falling back to WikiText-2...")
        try:
            ds = load_dataset("wikitext", "wikitext-2-raw-v1")
            train_texts = [text for text in ds["train"]["text"] if text.strip()][:dataset_size]
            # FIX: rejoined broken string literal (see above).
            print(f" β Loaded {len(train_texts)} samples from WikiText-2")
        except Exception as e2:
            print(f" β Failed to load WikiText-2: {e2}")
            print(" π² Using synthetic text data...")

            synthetic_texts = [
                "The quick brown fox jumps over the lazy dog.",
                "Machine learning is transforming technology.",
                "Edge computing enables local AI processing.",
                "BitTransformerLM uses bit-native processing.",
                "CPU optimization improves inference speed.",
                "Neural networks learn from training data.",
                "Transformers use attention mechanisms.",
                "Language models understand text patterns.",
            ]
            # Repeat the small pool until dataset_size samples exist.
            train_texts = (synthetic_texts * (dataset_size // len(synthetic_texts) + 1))[:dataset_size]
            # FIX: rejoined broken string literal (see above).
            print(f" β Generated {len(train_texts)} synthetic samples")

    print(" π Converting text to bits...")
    train_sequences = []
    valid_sequences = []

    for i, text in enumerate(train_texts):
        try:
            bits = text_to_bits(text)[:max_len]
            if len(bits) < max_len:
                bits.extend([0] * (max_len - len(bits)))  # zero-pad to max_len

            # Positional 80/20 train/validation split.
            if i < len(train_texts) * 0.8:
                train_sequences.append(bits)
            else:
                valid_sequences.append(bits)

        except Exception as e:
            print(f" β οΈ Failed to convert text to bits: {e}")
            continue

    train_tensor = torch.tensor(train_sequences, dtype=torch.long)
    # Fall back to a slice of the training data if the split produced no
    # validation sequences.
    valid_tensor = torch.tensor(valid_sequences, dtype=torch.long) if valid_sequences else train_tensor[:16]

    print(f" π Dataset Statistics:")
    print(f"   Training sequences: {len(train_sequences)}")
    print(f"   Validation sequences: {len(valid_sequences)}")
    # FIX: previously printed the hard-coded value "30,052"; report the
    # actual configured sequence length instead.
    print(f"   Sequence length: {max_len}")
    print(f"   Training tensor shape: {train_tensor.shape}")

    return train_tensor, valid_tensor, train_texts[:len(train_sequences)]
| |
|
| |
|
def train_cpu_optimized_model(model, train_data, valid_data, epochs=5):
    """Train the model with CPU-optimized settings.

    Runs a standard next-bit prediction loop (cross-entropy over the
    shifted sequence) under CPU autocast, with gradient clipping and a
    scheduler stepped once per batch.

    Args:
        model: BitTransformerLM instance to train (modified in place).
        train_data: LongTensor of shape (N, seq_len) of bit sequences.
        valid_data: LongTensor used for per-epoch validation (may be empty).
        epochs: Number of passes over the training data.

    Returns:
        Tuple ``(model, train_losses)`` with one average loss per epoch.

    Raises:
        ValueError: If ``train_data`` is empty.
    """
    print(f"π Training CPU-optimized BitTransformerLM for {epochs} epochs...")

    if len(train_data) == 0:
        raise ValueError("No training data available - check dataset loading")

    model.train()
    set_dropout(model, 0.1)

    # Small batch and moderate LR suit the tiny model / CPU budget.
    batch_size = 4
    learning_rate = 5e-4
    total_steps = max(1, epochs * (len(train_data) // batch_size))

    optimizer, scheduler = configure_optimizer(
        model,
        lr=learning_rate,
        total_steps=total_steps,
        weight_decay=0.01
    )

    print(f" π Training Configuration:")
    print(f"   Batch size: {batch_size}")
    print(f"   Learning rate: {learning_rate}")
    print(f"   Total steps: {total_steps}")
    print(f"   CPU autocast: Enabled")

    train_losses = []

    for epoch in range(epochs):
        print(f"\n π Epoch {epoch + 1}/{epochs}")
        epoch_losses = []
        epoch_start_time = time.time()

        # Reshuffle every epoch.
        perm = torch.randperm(len(train_data))
        train_data_shuffled = train_data[perm]

        for batch_idx in range(0, len(train_data_shuffled), batch_size):
            batch_end = min(batch_idx + batch_size, len(train_data_shuffled))
            batch = train_data_shuffled[batch_idx:batch_end]

            if len(batch) == 0:
                continue

            optimizer.zero_grad()

            with cpu_autocast():
                logits, telemetry = model(batch)

                # Next-bit prediction: logits at t predict the bit at t+1.
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = batch[:, 1:].reshape(-1)
                loss = F.cross_entropy(pred, target)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Don't step the scheduler past its configured horizon.
            if scheduler.last_epoch < scheduler.total_steps - 1:
                scheduler.step()

            batch_loss = loss.item()
            epoch_losses.append(batch_loss)

            # Periodic progress report with safety telemetry.
            if (batch_idx // batch_size) % 50 == 0:
                avg_loss = sum(epoch_losses[-10:]) / min(10, len(epoch_losses))
                telemetry_str = f"K={telemetry.get('K', 0):.3f}, C={telemetry.get('C', 0):.3f}, S={telemetry.get('S', 0):.3f}"
                print(f"   Step {batch_idx // batch_size}: Loss={avg_loss:.4f}, {telemetry_str}")

        epoch_time = time.time() - epoch_start_time
        # FIX: guard against division by zero when an epoch produced no
        # batches (all skipped).
        avg_epoch_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan")
        train_losses.append(avg_epoch_loss)

        print(f" β±οΈ Epoch {epoch + 1} completed in {epoch_time:.1f}s, Avg Loss: {avg_epoch_loss:.4f}")

        # Per-epoch validation pass (dropout off, no gradients).
        if len(valid_data) > 0:
            model.eval()
            set_dropout(model, 0.0)

            with torch.no_grad():
                with cpu_autocast():
                    val_batch = valid_data[:min(8, len(valid_data))]
                    val_logits, val_telemetry = model(val_batch)
                    val_pred = val_logits[:, :-1, :].reshape(-1, 2)
                    val_target = val_batch[:, 1:].reshape(-1)
                    val_loss = F.cross_entropy(val_pred, val_target).item()

            print(f" π Validation Loss: {val_loss:.4f}")
            print(f" π Telemetry - K: {val_telemetry.get('K', 0):.3f}, C: {val_telemetry.get('C', 0):.3f}, S: {val_telemetry.get('S', 0):.3f}")

            # Back to training mode for the next epoch.
            model.train()
            set_dropout(model, 0.1)

    # FIX: this message was an unterminated f-string split across two
    # lines in the original (a SyntaxError).
    print(f"\nβ Training completed!")
    print(f" Final training loss: {train_losses[-1]:.4f}")

    return model, train_losses
| |
|
| |
|
def test_model_inference(model, test_texts):
    """Test the trained model with inference and safety checks.

    Runs next-bit prediction on up to three sample texts, then exercises
    ``hil_safe_inference`` on a fixed prompt. All failures are caught and
    reported rather than raised (best-effort smoke test).

    Args:
        model: Trained BitTransformerLM.
        test_texts: Iterable of sample strings; the first three are used.
    """
    print("\nπ§ͺ Testing Model Inference...")

    model.eval()
    set_dropout(model, 0.0)

    test_samples = test_texts[:3]

    for i, text in enumerate(test_samples):
        print(f"\n Test {i + 1}: {text[:50]}...")

        try:
            # Fixed 64-bit input window, zero-padded if needed.
            input_bits = text_to_bits(text)[:64]
            if len(input_bits) < 64:
                input_bits.extend([0] * (64 - len(input_bits)))

            input_tensor = torch.tensor([input_bits], dtype=torch.long)

            with torch.no_grad():
                with cpu_autocast():
                    logits, telemetry = model(input_tensor)

            # Sample a single next bit from the final-position distribution.
            next_token_logits = logits[0, -1, :]
            next_token_probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(next_token_probs, 1).item()

            print(f"   Input bits: {input_bits[:16]}... (showing first 16)")
            print(f"   Next token prediction: {next_token}")
            print(f"   Next token confidence: {next_token_probs[next_token]:.3f}")
            print(f"   Telemetry - K: {telemetry.get('K', 0):.3f}, C: {telemetry.get('C', 0):.3f}, S: {telemetry.get('S', 0):.3f}")

        except Exception as e:
            print(f" β Inference failed: {e}")

    # Human-in-the-loop safe inference smoke test.
    print(f"\nπ‘οΈ Testing Safe Inference...")
    try:
        test_prompt = "The future of AI is"
        prompt_bits = text_to_bits(test_prompt)
        prompt_tensor = torch.tensor([prompt_bits], dtype=torch.long)

        with cpu_autocast():
            safe_result = hil_safe_inference(model, prompt_tensor, max_new_tokens=16)

        if safe_result is not None:
            # FIX: this message was an unterminated f-string split across
            # two lines in the original (a SyntaxError).
            print(f" β Safe inference successful")
            print(f"   Generated {len(safe_result[0]) - len(prompt_bits)} new tokens")
        else:
            print(f" β οΈ Safe inference blocked by safety gates")

    except Exception as e:
        print(f" β Safe inference test failed: {e}")
| |
|
| |
|
def benchmark_cpu_performance(model):
    """Benchmark the model's CPU performance.

    Times forward passes over a grid of batch sizes and sequence lengths
    (3 warmup runs, 10 timed runs each) and prints the best throughput.

    Args:
        model: Trained BitTransformerLM to benchmark.

    Returns:
        List of per-configuration dicts with keys ``batch_size``,
        ``seq_len``, ``avg_time_ms``, ``throughput_tokens_per_sec``.
    """
    print("\nβ‘ CPU Performance Benchmark...")

    model.eval()
    set_dropout(model, 0.0)

    results = []

    for bs in [1, 2, 4]:
        for length in [32, 64, 128]:
            print(f"\n Testing batch_size={bs}, seq_len={length}")

            sample = torch.randint(0, 2, (bs, length), dtype=torch.long)

            # Warmup: 3 untimed forward passes.
            with torch.no_grad():
                with cpu_autocast():
                    for _ in range(3):
                        _, _ = model(sample)

            # Timed runs.
            elapsed = []
            for _ in range(10):
                t0 = time.time()
                with torch.no_grad():
                    with cpu_autocast():
                        logits, telemetry = model(sample)
                elapsed.append(time.time() - t0)

            mean_time = sum(elapsed) / len(elapsed)
            tokens_per_sec = (bs * length) / mean_time

            results.append({
                'batch_size': bs,
                'seq_len': length,
                'avg_time_ms': mean_time * 1000,
                'throughput_tokens_per_sec': tokens_per_sec,
            })

            print(f"   Average time: {mean_time * 1000:.2f}ms")
            print(f"   Throughput: {tokens_per_sec:.0f} tokens/sec")

    print(f"\nπ Performance Summary:")
    best = max(results, key=lambda r: r['throughput_tokens_per_sec'])
    print(f" Best throughput: {best['throughput_tokens_per_sec']:.0f} tokens/sec")
    print(f" At batch_size={best['batch_size']}, seq_len={best['seq_len']}")

    return results
| |
|
| |
|
def quantize_for_deployment(model):
    """Apply dynamic quantization for deployment.

    Quantizes via ``quantize_dynamic`` and sanity-checks that both models
    still run a forward pass. On any failure the original model is
    returned unchanged.

    Args:
        model: Trained BitTransformerLM.

    Returns:
        The quantized model, or ``model`` if quantization failed.
    """
    print("\nποΈ Applying Dynamic Quantization for Deployment...")

    try:
        quantized_model = quantize_dynamic(model)

        original_params = sum(p.numel() for p in model.parameters())
        quantized_params = sum(p.numel() for p in quantized_model.parameters())

        print(f" Original parameters: {original_params:,}")
        print(f" Quantized parameters: {quantized_params:,}")
        print(f" Model size reduction: ~50% (FP32 -> INT8)")

        # Smoke test: both models must still execute a forward pass.
        # Outputs are intentionally unused — we only care that no
        # exception is raised.
        test_input = torch.randint(0, 2, (1, 32), dtype=torch.long)

        with torch.no_grad():
            _ = model(test_input)
            _ = quantized_model(test_input)

        # FIX: this message was an unterminated f-string split across two
        # lines in the original (a SyntaxError).
        print(f" β Quantization successful - model still functional")

        return quantized_model

    except Exception as e:
        print(f" β Quantization failed: {e}")
        return model
| |
|
| |
|
def main():
    """Main training and testing pipeline.

    Builds the CPU-optimized model, loads data, trains, runs inference
    tests and a CPU benchmark, quantizes, saves both model variants to
    ``weights/``, and prints a final summary.
    """
    print("π CPU-Optimized BitTransformerLM Training Pipeline")
    print("=" * 60)

    model = create_optimal_cpu_model()

    train_data, valid_data, train_texts = load_training_dataset(dataset_size=256, max_len=128)

    trained_model, train_losses = train_cpu_optimized_model(model, train_data, valid_data, epochs=3)

    test_model_inference(trained_model, train_texts)

    benchmark_results = benchmark_cpu_performance(trained_model)

    quantized_model = quantize_for_deployment(trained_model)

    print("\nπΎ Saving Models...")

    os.makedirs("weights", exist_ok=True)

    try:
        save_model(trained_model, "weights/cpu_edge_model.pt.gz")
        # FIX: these three messages were unterminated f-strings split
        # across two lines in the original (SyntaxErrors).
        print(" β Saved trained model: weights/cpu_edge_model.pt.gz")

        save_model(quantized_model, "weights/cpu_edge_model_quantized.pt.gz")
        print(" β Saved quantized model: weights/cpu_edge_model_quantized.pt.gz")

    except Exception as e:
        print(f" β οΈ Model saving failed: {e}")

    print("\n" + "=" * 60)
    print("π CPU-Optimized BitTransformerLM Training Complete!")
    print("=" * 60)

    total_params = sum(p.numel() for p in trained_model.parameters())
    final_loss = train_losses[-1] if train_losses else "N/A"
    best_throughput = max(benchmark_results, key=lambda x: x['throughput_tokens_per_sec'])

    print(f"π Final Results:")
    print(f"   Model Parameters: {total_params:,}")
    print(f"   Final Training Loss: {final_loss}")
    print(f"   Peak Throughput: {best_throughput['throughput_tokens_per_sec']:.0f} tokens/sec")
    print(f"   Model Size (quantized): ~{total_params * 1 / 1024 / 1024:.1f}MB")
    print(f"   CPU Optimizations: BF16 autocast, no gradient checkpointing, small chunks")
    print(f"   Edge Ready: β Optimized for consumer CPUs")
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |