|
|
|
|
|
""" |
|
|
BitTransformerLM Full Bi-Directional Attention Training Script |
|
|
=============================================================== |
|
|
|
|
|
This script implements the breakthrough Fixed RL Adafactor training configuration |
|
|
for production-scale BitTransformerLM training with FULL BI-DIRECTIONAL UNCHUNKED ATTENTION. |
|
|
|
|
|
Configuration: |
|
|
- Model: 16M parameters (d_model=512, nhead=16, num_layers=8) |
|
|
- Attention: FULL BI-DIRECTIONAL UNCHUNKED (chunk_size=None) |
|
|
- Optimizer: Fixed LR Adafactor (identical to breakthrough config) |
|
|
- Features: Reversible layers, ACT, QAT, compression |
|
|
- Data: HuggingFace WCNegentropy/BitTransformerLM dataset |
|
|
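- Checkpointing: after every epoch. On restart, only the model weights are reloaded from the latest checkpoint; the epoch counter, optimizer state and best-loss tracking start fresh.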
|
|
""" |
|
|
|
|
|
import sys |
|
|
import os |
|
|
import json |
|
|
import time |
|
|
import logging |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from typing import Optional, Dict, Any |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from datasets import load_dataset |
|
|
from huggingface_hub import login |
|
|
|
|
|
|
|
|
sys.path.append('/data') |
|
|
sys.path.append('/data/BitTransformerLM') |
|
|
|
|
|
from bit_transformer import ( |
|
|
BitTransformerLM, |
|
|
text_to_bits, |
|
|
bits_to_text, |
|
|
save_model, |
|
|
load_model, |
|
|
set_dropout |
|
|
) |
|
|
from BTLM_Extensions import configure_adafactor_optimizer |
|
|
|
|
|
|
|
|
# Log to both a file in the working directory and the console.
logging.basicConfig(
    handlers=(
        logging.FileHandler('full_attention_training.log'),
        logging.StreamHandler(),
    ),
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class ProductionTrainer:
    """Production-grade BitTransformerLM trainer with breakthrough configuration.

    Typical use (see ``main``): ``setup_model()`` -> ``setup_optimizer()`` ->
    ``setup_dataset()`` -> optional ``load_checkpoint()`` -> ``train(n)``.
    All tensors live on the CPU (``self.device``).
    """

    # Config keys whose values must never be written to disk with checkpoints.
    _SECRET_CONFIG_KEYS = ('hf_token',)

    # (model attribute holding the weight, telemetry key) for each auxiliary
    # penalty of the form ``weight * (1 - telemetry[key])``.
    _REGULARIZERS = (
        ('lambda_K', 'negentropy_logits'),
        ('lambda_C', 'lz_complexity_logits'),
        ('lambda_S', 'symbiosis_score'),
    )

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration and prepare the checkpoint directory.

        Args:
            config: Training configuration. ``'checkpoint_dir'`` is required
                here (it is created if missing); other keys are read by the
                individual setup/training methods.
        """
        self.config = config
        self.device = torch.device('cpu')
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.dataset = None
        self.checkpoint_dir = Path(config['checkpoint_dir'])
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Training progress. current_epoch counts *completed* epochs; the
        # epoch currently running is reported as current_epoch + 1.
        self.current_epoch = 0
        self.total_steps = 0
        self.best_loss = float('inf')
        self.training_history = []

    # ------------------------------------------------------------------ setup

    def setup_model(self):
        """Create the ~16M-parameter BitTransformerLM with full (unchunked) attention.

        The defaults below are the breakthrough configuration; any keys in
        ``config['model_params']`` override them.

        Returns:
            The constructed model (also stored on ``self.model``).
        """
        logger.info("Setting up breakthrough BitTransformerLM with FULL BI-DIRECTIONAL UNCHUNKED ATTENTION...")

        model_kwargs = dict(
            d_model=512,
            nhead=16,
            num_layers=8,
            dim_feedforward=1024,
            max_seq_len=512,
            reversible=True,
            use_checkpoint=True,
            use_autocast=True,
            use_act=True,
            act_threshold=0.9,
            lambda_K=0.05,
            lambda_C=0.05,
            lambda_S=0.05,
            chunk_size=None,  # None = attend over the whole sequence at once
            overlap=0,
            full_attn_logging=True,
        )
        # Honour the architecture declared in the config instead of silently
        # ignoring it.
        model_kwargs.update(self.config.get('model_params') or {})

        self.model = BitTransformerLM(**model_kwargs).to(self.device)

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        in_target_range = 15_000_000 <= total_params <= 17_000_000

        logger.info(f"Model created: {total_params:,} total parameters ({trainable_params:,} trainable)")
        logger.info(f"Target: ~16M parameters - {'OK' if in_target_range else 'OUT OF RANGE'}")

        return self.model

    def setup_optimizer(self):
        """Set up the Fixed RL Adafactor optimizer and its scheduler.

        Returns:
            ``(optimizer, scheduler)`` as produced by
            ``configure_adafactor_optimizer``; both are also stored on ``self``.

        Raises:
            RuntimeError: if ``setup_model()`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("setup_model() must be called before setup_optimizer()")

        logger.info("Setting up Fixed RL Adafactor optimizer...")
        self.optimizer, self.scheduler = configure_adafactor_optimizer(
            self.model,
            lr=self.config['learning_rate'],
            weight_decay=self.config['weight_decay'],
            total_steps=self.config['total_steps'],
        )
        logger.info(f"Fixed RL Adafactor configured with LR={self.config['learning_rate']}")
        return self.optimizer, self.scheduler

    def setup_dataset(self):
        """Load WCNegentropy/BitTransformerLM and cut it into fixed-length bit windows.

        Each sample's ``original_text`` (or, failing that, ``text``) is
        converted to bits. Samples shorter than ``config['sequence_length']``
        bits are dropped, and the rest are cut into windows with 50% overlap.
        If loading or processing fails, or yields no windows, 1000 random bit
        sequences are used instead (with a warning).

        Returns:
            A ``(num_windows, sequence_length)`` long tensor, also stored on
            ``self.dataset``.
        """
        logger.info("Loading WCNegentropy/BitTransformerLM dataset...")

        # huggingface_hub.login(token=None) prompts for a token interactively,
        # which hangs a headless run; only authenticate when one was supplied.
        hf_token = self.config.get('hf_token')
        if hf_token:
            login(token=hf_token)

        seq_len = self.config['sequence_length']
        try:
            dataset = load_dataset("WCNegentropy/BitTransformerLM")
            logger.info(f"Dataset loaded: {dataset}")

            train_data = dataset['train'] if 'train' in dataset else dataset
            logger.info(f"Training samples: {len(train_data)}")

            training_sequences = self._build_training_windows(train_data, seq_len)
            if not training_sequences:
                # An empty tensor here would only fail later (division by zero
                # in train_epoch); treat it like any other loading failure.
                raise ValueError(f"no samples with at least {seq_len} bits")

            self.dataset = torch.tensor(training_sequences, dtype=torch.long)
            logger.info(f"Created training dataset: {self.dataset.shape}")

        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            logger.info("Falling back to synthetic bit data...")
            self.dataset = torch.randint(0, 2, (1000, seq_len))
            logger.warning("Using synthetic data - replace with real dataset for production")

        return self.dataset

    def _build_training_windows(self, train_data, seq_len: int) -> list:
        """Convert every usable sample to bits and return all of its windows."""
        windows = []
        valid_sequences = 0
        total = len(train_data)
        for i, sample in enumerate(train_data):
            if i % 1000 == 0:
                logger.info(f"Processing sample {i}/{total}")

            text = self._sample_text(sample)
            if not (text and text.strip()):
                continue

            bits = text_to_bits(text)
            if len(bits) >= seq_len:
                valid_sequences += 1
                windows.extend(self._windows(bits, seq_len))

        logger.info(f"Processed {valid_sequences} valid bit sequences")
        return windows

    @staticmethod
    def _sample_text(sample) -> Optional[str]:
        """Return the first non-empty of ``original_text`` / ``text``, else None."""
        for key in ('original_text', 'text'):
            if key in sample and sample[key]:
                return sample[key]
        return None

    @staticmethod
    def _windows(bits, seq_len: int) -> list:
        """Cut ``bits`` into full ``seq_len`` windows with a stride of half a window.

        The stride is at least 1, so ``seq_len == 1`` does not produce a zero
        step.
        """
        stride = max(1, seq_len // 2)
        return [bits[i:i + seq_len] for i in range(0, len(bits) - seq_len + 1, stride)]

    # ------------------------------------------------------------ checkpoints

    def _persistable_config(self) -> Dict[str, Any]:
        """Return a shallow copy of the config with secrets redacted for disk."""
        safe = dict(self.config)
        for key in self._SECRET_CONFIG_KEYS:
            if safe.get(key) is not None:
                safe[key] = '<redacted>'
        return safe

    def _checkpoint_payload(self, epoch: int, loss: float) -> Dict[str, Any]:
        """Assemble the full training state that is written to every checkpoint."""
        return {
            'epoch': epoch,
            'total_steps': self.total_steps,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'loss': loss,
            'best_loss': self.best_loss,
            'config': self._persistable_config(),
            'training_history': self.training_history,
            'timestamp': datetime.now().isoformat(),
        }

    @staticmethod
    def _atomic_torch_save(obj, path: Path) -> None:
        """``torch.save`` to a temp file, then rename it over ``path``.

        A crash mid-write therefore never leaves a truncated checkpoint at
        ``path``.
        """
        tmp_path = path.with_name(path.name + '.tmp')
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)

    def save_checkpoint(self, epoch: int, loss: float, is_best: bool = False):
        """Save the full training state after an epoch.

        Writes ``checkpoint_latest.pt`` and ``checkpoint_epoch_<epoch>.pt``,
        plus ``checkpoint_best.pt`` when ``is_best`` is true. It also refreshes
        ``training_config.json``. Secret config values (the HF token) are
        redacted in everything written.

        Args:
            epoch: 1-based number of the epoch just completed.
            loss: Average loss of that epoch.
            is_best: Whether this epoch achieved the best loss so far.
        """
        checkpoint_data = self._checkpoint_payload(epoch, loss)

        latest_path = self.checkpoint_dir / 'checkpoint_latest.pt'
        self._atomic_torch_save(checkpoint_data, latest_path)
        logger.info(f"Saved checkpoint: {latest_path}")

        epoch_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt'
        self._atomic_torch_save(checkpoint_data, epoch_path)

        if is_best:
            best_path = self.checkpoint_dir / 'checkpoint_best.pt'
            self._atomic_torch_save(checkpoint_data, best_path)
            logger.info(f"NEW BEST MODEL! Loss: {loss:.6f} -> {best_path}")

        config_path = self.checkpoint_dir / 'training_config.json'
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self._persistable_config(), f, indent=2)

    def _save_emergency_checkpoint(self) -> None:
        """Dump the current (possibly mid-epoch) state to ``checkpoint_emergency.pt``.

        A separate file is used on purpose: the state may be partially trained
        or diverged, so it must not overwrite ``checkpoint_latest.pt`` or a
        completed epoch's checkpoint.
        """
        path = self.checkpoint_dir / 'checkpoint_emergency.pt'
        self._atomic_torch_save(self._checkpoint_payload(self.current_epoch, float('inf')), path)
        logger.info(f"Saved emergency checkpoint: {path}")

    def load_checkpoint(self, checkpoint_path: Optional[str] = None) -> bool:
        """Load model weights from a checkpoint, then restart training from epoch 1.

        Only the model weights are restored. The epoch and step counters, best
        loss and history are reset, and the optimizer/scheduler state is not
        loaded.

        Args:
            checkpoint_path: Checkpoint file to read. Defaults to
                ``<checkpoint_dir>/checkpoint_latest.pt``.

        Returns:
            True if weights were loaded; False if the file is missing or
            loading failed (the error is logged).
        """
        if checkpoint_path is None:
            path = self.checkpoint_dir / 'checkpoint_latest.pt'
        else:
            path = Path(checkpoint_path)

        if not path.exists():
            logger.info("No checkpoint found - starting fresh training")
            return False

        logger.info(f"Loading model weights from: {path}")
        try:
            checkpoint = torch.load(path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {e}")
            return False

        # current_epoch counts completed epochs and train_epoch reports
        # current_epoch + 1, so 0 is what makes the next epoch "epoch 1".
        self.current_epoch = 0
        self.total_steps = 0
        self.best_loss = float('inf')
        self.training_history = []

        logger.info("Loaded model weights, restarting training from epoch 1, step 0")
        return True

    # --------------------------------------------------------------- training

    @staticmethod
    def _telemetry_scalar(telemetry, key: str):
        """Return telemetry[key] as a Python number (mean of a tensor); 0.0 if absent."""
        value = telemetry.get(key, 0.0)
        return value.mean().item() if torch.is_tensor(value) else value

    def _add_regularizers(self, loss, telemetry):
        """Return ``loss`` plus every enabled ``lambda * (1 - metric)`` penalty.

        A penalty is applied only when its lambda on the model is positive and
        the model reported the matching telemetry key. Tensor penalties are
        averaged to a scalar first.
        """
        for lambda_attr, key in self._REGULARIZERS:
            weight = getattr(self.model, lambda_attr)
            if weight > 0 and key in telemetry:
                term = weight * (1 - telemetry[key])
                loss = loss + (term.mean() if torch.is_tensor(term) else term)
        return loss

    def training_step(self, batch: torch.Tensor) -> Dict[str, float]:
        """Run one optimisation step on ``batch`` and return its metrics.

        Returns:
            Dict with the total ``loss`` (including regularisers), the mean
            ``K``/``C``/``S`` telemetry and the current ``lr``.
        """
        self.model.train()
        set_dropout(self.model, self.config['dropout'])

        batch = batch.to(self.device)
        self.optimizer.zero_grad()

        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
            logits, telemetry = self.model(batch)

        # bf16 keeps only ~3 significant digits; compute the loss (which is
        # logged and compared across epochs) in float32.
        logits = logits.float()

        if logits.dim() == 3:
            # Next-bit prediction: the logits at position t predict bit t + 1.
            pred = logits[:, :-1]
            targets = batch[:, 1:]
            loss = F.cross_entropy(pred.reshape(-1, pred.size(-1)), targets.reshape(-1))
        else:
            loss = F.cross_entropy(logits, batch)

        loss = self._add_regularizers(loss, telemetry)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['max_grad_norm'])

        self.optimizer.step()
        if self.scheduler:
            self.scheduler.step()

        self.total_steps += 1

        return {
            'loss': loss.item(),
            'K': self._telemetry_scalar(telemetry, 'negentropy_logits'),
            'C': self._telemetry_scalar(telemetry, 'lz_complexity_logits'),
            'S': self._telemetry_scalar(telemetry, 'symbiosis_score'),
            'lr': self.optimizer.param_groups[0]['lr'],
        }

    def train_epoch(self) -> Dict[str, float]:
        """Train for one shuffled pass over ``self.dataset``.

        Returns:
            Summary with the 1-based ``epoch``, ``avg_loss``, wall-clock
            ``time`` and average ``K``/``C``/``S``. It is also appended to
            ``training_history``.

        Raises:
            RuntimeError: if ``setup_dataset()`` has not been called.
            ValueError: if the dataset holds fewer sequences than one batch.
        """
        from torch.utils.data import DataLoader

        logger.info(f"Starting epoch {self.current_epoch + 1}")

        if self.dataset is None:
            raise RuntimeError("setup_dataset() must be called before training")

        batch_size = self.config['batch_size']
        dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
        num_batches = len(dataloader)
        if num_batches == 0:
            # drop_last=True yields no batches at all for a dataset smaller
            # than one batch; fail clearly instead of dividing by zero below.
            raise ValueError(
                f"Dataset has {len(self.dataset)} sequences, fewer than "
                f"batch_size={batch_size}; nothing to train on"
            )

        epoch_losses = []
        epoch_metrics = {'K': [], 'C': [], 'S': []}
        start_time = time.time()

        for step, batch in enumerate(dataloader):
            metrics = self.training_step(batch)

            epoch_losses.append(metrics['loss'])
            for name, values in epoch_metrics.items():
                values.append(metrics[name])

            if step % self.config['log_interval'] == 0:
                logger.info(
                    f"Epoch {self.current_epoch + 1}, Step {step}/{num_batches}: "
                    f"Loss={metrics['loss']:.6f}, K={metrics['K']:.3f}, "
                    f"C={metrics['C']:.3f}, S={metrics['S']:.3f}, LR={metrics['lr']:.2e}"
                )

        epoch_time = time.time() - start_time
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        avg_metrics = {name: sum(values) / len(values) for name, values in epoch_metrics.items()}

        epoch_summary = {
            'epoch': self.current_epoch + 1,
            'avg_loss': avg_loss,
            'time': epoch_time,
            **avg_metrics,
        }
        self.training_history.append(epoch_summary)

        logger.info(
            f"Epoch {self.current_epoch + 1} completed in {epoch_time:.1f}s: "
            f"Avg Loss={avg_loss:.6f}, K={avg_metrics['K']:.3f}, "
            f"C={avg_metrics['C']:.3f}, S={avg_metrics['S']:.3f}"
        )
        return epoch_summary

    def train(self, num_epochs: int):
        """Main training loop: train, track the best loss and checkpoint each epoch.

        Ctrl-C stops the loop cleanly. Any other error is logged, the current
        state is dumped to ``checkpoint_emergency.pt`` (best effort) and the
        error is re-raised.
        """
        logger.info(f"Starting production training for {num_epochs} epochs...")
        logger.info("Breakthrough configuration: Fixed RL Adafactor + 16M BitTransformerLM")

        for _ in range(num_epochs):
            try:
                epoch_summary = self.train_epoch()
                avg_loss = epoch_summary['avg_loss']

                is_best = avg_loss < self.best_loss
                if is_best:
                    self.best_loss = avg_loss

                self.save_checkpoint(self.current_epoch + 1, avg_loss, is_best)
                self.current_epoch += 1

                logger.info(f"=== EPOCH {self.current_epoch} COMPLETE ===")
                logger.info(f"Loss: {avg_loss:.6f} (best: {self.best_loss:.6f})")

                if avg_loss < 3.0:
                    logger.info("π BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")

            except KeyboardInterrupt:
                logger.info("Training interrupted by user")
                break

            except Exception as e:
                logger.error(f"Error in epoch {self.current_epoch + 1}: {e}")
                try:
                    self._save_emergency_checkpoint()
                except Exception:
                    # Best effort only: report why, but surface the original error.
                    logger.exception("Emergency checkpoint save failed")
                raise
|
|
|
|
|
|
|
|
def main():
    """Configure, build and run the full-attention BitTransformerLM training job.

    The Hugging Face token is read from the ``HF_TOKEN`` environment variable
    (None when unset), so no secret has to be edited into this file.
    """
    config = {
        # Architecture; passed to BitTransformerLM by ProductionTrainer.setup_model.
        'model_params': {
            'd_model': 512,
            'nhead': 16,
            'num_layers': 8,
            'dim_feedforward': 1024,
        },

        # Optimisation.
        'learning_rate': 1e-3,
        'weight_decay': 0.01,
        'batch_size': 4,
        'sequence_length': 256,
        'num_epochs': 50,
        'max_grad_norm': 1.0,
        'dropout': 0.1,
        # NOTE(review): fixed budget handed to the scheduler; it is not derived
        # from the dataset size or num_epochs — confirm it matches the run.
        'total_steps': 10000,

        # Data access.
        'hf_token': os.environ.get('HF_TOKEN'),

        # Logging / checkpointing.
        'log_interval': 10,
        'checkpoint_dir': '/data/BitTransformerLM/checkpoints',
    }

    trainer = ProductionTrainer(config)

    trainer.setup_model()
    trainer.setup_optimizer()
    trainer.setup_dataset()

    # Reuse the latest checkpoint's model weights if one exists.
    trainer.load_checkpoint()

    logger.info("π STARTING BREAKTHROUGH BITRANSFORMERLM TRAINING!")
    logger.info("Configuration: Fixed RL Adafactor + 16M parameters + CPU training")

    trainer.train(config['num_epochs'])

    logger.info("Training completed!")
    logger.info(f"Best loss achieved: {trainer.best_loss:.6f}")
    logger.info(f"Checkpoints saved to: {trainer.checkpoint_dir}")


if __name__ == "__main__":
    main()