#!/usr/bin/env python3
"""
BREAKTHROUGH BitTransformerLM Training Script
===========================================
Using the ACTUAL BitTransformerLM model and training infrastructure,
configured for the Fixed RL Adafactor breakthrough results.
"""
import ast
import sys
import os
import logging
from pathlib import Path

import torch
from datasets import load_dataset
from huggingface_hub import login
# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    train_loop,
    save_model,
    load_model,
    set_dropout,
)
from BTLM_Extensions import configure_adafactor_optimizer
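# NOTE (assumption): BTLM_Extensions is a project-local helper module; this
# script assumes configure_adafactor_optimizer(model, lr, weight_decay,
# total_steps) returns an (optimizer, scheduler) pair wired for a fixed
# (non-relative-step) learning rate.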
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('breakthrough_training.log'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
def load_and_prepare_dataset():
    """Load the HF dataset and convert it to bit tensors."""
    logger.info("Loading WCNegentropy/BitTransformerLM dataset...")

    # Log in to Hugging Face
    hf_token = os.getenv('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
    else:
        logger.warning("HF_TOKEN environment variable not set")

    # Load dataset
    dataset = load_dataset("WCNegentropy/BitTransformerLM")
    train_data = dataset['train']
    logger.info(f"Dataset loaded: {len(train_data)} samples")

    # Process dataset - the HF dataset already provides a bit_sequence field
    bit_sequences = []
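    # Expected record shape (an assumption inferred from the handling below):
    #   {'bit_sequence': [0, 1, ...] or "[0, 1, ...]",
    #    'original_text': '...', 'text': '...'}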
    for sample in train_data:
        if 'bit_sequence' in sample and sample['bit_sequence'] is not None:
            bits = sample['bit_sequence']
            if isinstance(bits, str):
                # Parse a string representation such as "[0, 1, 1, ...]"
                try:
                    bits = ast.literal_eval(bits)
                except (ValueError, SyntaxError):
                    bits = None
            if isinstance(bits, list) and len(bits) > 0:
                bit_sequences.append(bits)
            else:
                # Fallback: convert original_text to bits
                text = sample.get('original_text', '')
                if text:
                    bit_sequences.append(text_to_bits(text))
        else:
            # Fallback: convert text to bits
            text = sample.get('text', '') or sample.get('original_text', '')
            if text:
                bit_sequences.append(text_to_bits(text))
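    # Assumption: text_to_bits encodes each character as its 8-bit code, e.g.
    # text_to_bits("A") -> [0, 1, 0, 0, 0, 0, 0, 1] for ASCII 65.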
logger.info(f"Processed {len(bit_sequences)} bit sequences")
# Create training tensors with proper sequence length
max_len = 512 # BitTransformerLM default max_seq_len
training_sequences = []
for bits in bit_sequences:
# Split long sequences into chunks
for i in range(0, len(bits) - max_len + 1, max_len // 2):
seq = bits[i:i + max_len]
if len(seq) == max_len: # Only use full-length sequences
training_sequences.append(seq)
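    # Worked example of the chunking above: a 1,280-bit sequence yields windows
    # at offsets 0, 256, 512, and 768, i.e. four full 512-bit training
    # sequences; anything shorter than 512 bits is dropped entirely.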
    # Convert to tensor
    data_tensor = torch.tensor(training_sequences, dtype=torch.long)
    logger.info(f"Created training tensor: {data_tensor.shape}")
    return data_tensor

def create_breakthrough_model():
    """Create the exact breakthrough BitTransformerLM configuration."""
    logger.info("Creating breakthrough BitTransformerLM model...")

    # Breakthrough configuration using the actual BitTransformerLM parameters
    model = BitTransformerLM(
        d_model=512,            # Breakthrough config
        nhead=16,               # 16 attention heads
        num_layers=8,           # 8 layers for ~16M params
        dim_feedforward=1024,   # 2x d_model
        max_seq_len=512,        # Match data preparation
        reversible=True,        # Memory efficiency
        use_checkpoint=True,    # Gradient checkpointing
        use_autocast=True,      # Mixed precision
        use_act=True,           # Adaptive Computation Time
        act_threshold=0.9,
        lambda_K=0.05,          # Safety telemetry weights
        lambda_C=0.05,
        lambda_S=0.05,
    )
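    # Rough size check (assumption: standard transformer bookkeeping): per
    # layer, attention weights ~ 4 * d_model^2 ~ 1.05M and the feed-forward
    # block ~ 2 * d_model * dim_feedforward ~ 1.05M, so 8 layers give
    # roughly 8 * 2.1M ~ 16.8M parameters, in line with the ~16M target.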
    # Calculate parameter count
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model created: {total_params:,} parameters")
    in_range = 15_000_000 <= total_params <= 17_000_000
    logger.info(f"Target: ~16M parameters - {'OK' if in_range else 'OFF TARGET'}")
    return model

def main():
    """Main training function."""
    logger.info("STARTING BREAKTHROUGH BITTRANSFORMERLM TRAINING!")
    logger.info("Using the actual BitTransformerLM model and train_loop")

    # Load dataset
    data = load_and_prepare_dataset()

    # Create model
    model = create_breakthrough_model()
    # CRITICAL: use the fixed-LR Adafactor setup (the breakthrough secret!)
    logger.info("Configuring fixed-LR Adafactor optimizer...")
    optimizer, scheduler = configure_adafactor_optimizer(
        model,
        lr=1e-3,           # Fixed learning rate - key to the breakthrough!
        weight_decay=0.01,
        total_steps=5000,  # Estimated total steps
    )
    logger.info("Fixed-LR Adafactor configured with LR=0.001")
    # Training configuration
    training_config = {
        'epochs': 20,          # Reasonable number of epochs
        'batch_size': 4,       # Adjust based on memory
        'accum_steps': 4,      # Gradient accumulation
        'amp': True,           # Mixed precision
        'log': True,           # Enable logging
        'compress_prob': 0.0,  # Start with no compression
        'optimizer': optimizer,
        'scheduler': scheduler,
    }
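    # Effective batch size = batch_size * accum_steps = 4 * 4 = 16 sequences
    # per optimizer step (assuming train_loop applies accum_steps this way).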
logger.info(f"Training configuration: {training_config}")
logger.info("Starting training loop...")
# Use the ACTUAL BitTransformerLM train_loop function!
metrics = train_loop(
model=model,
data=data,
**training_config
)
    # Save the trained model
    checkpoint_dir = Path('/data/BitTransformerLM/checkpoints')
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    model_path = checkpoint_dir / 'breakthrough_model.pt'
    save_model(model, model_path)
    logger.info(f"Model saved to: {model_path}")
    # Log final metrics
    if metrics:
        final_metrics = metrics[-1]
        logger.info("TRAINING COMPLETED!")
        logger.info(f"Final raw_loss: {final_metrics['raw_loss']:.6f}")
        logger.info(f"Final raw_acc: {final_metrics['raw_acc']:.3f}")

        # Check for breakthrough performance
        if final_metrics['raw_loss'] < 3.0:
            logger.info("BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")

    logger.info("Breakthrough training completed successfully!")

if __name__ == "__main__":
    main()