File size: 6,780 Bytes
5a75fec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
#!/usr/bin/env python3
"""
BREAKTHROUGH BitTransformerLM Training Script
===========================================

Using the ACTUAL BitTransformerLM model and training infrastructure,
configured for the Fixed RL Adafactor breakthrough results.
"""

import sys
import os
import logging
from pathlib import Path

import torch
from datasets import load_dataset
from huggingface_hub import login

# Add paths for imports
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    train_loop,
    save_model,
    load_model,
    set_dropout
)
from BTLM_Extensions import configure_adafactor_optimizer

# Setup logging: every record goes both to a persistent file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('breakthrough_training.log'),  # persistent run log
        logging.StreamHandler()  # console (stderr) output
    ]
)
# Module-level logger shared by all functions in this script.
logger = logging.getLogger(__name__)

def load_and_prepare_dataset():
    """Load the HF dataset and convert it into a bit-sequence training tensor.

    Reads the WCNegentropy/BitTransformerLM dataset, extracting the
    ``bit_sequence`` field where present (parsing stringified lists safely)
    and falling back to converting ``original_text``/``text`` via
    ``text_to_bits``. Long sequences are chunked into 512-bit windows with
    50% overlap; only full-length windows are kept.

    Returns:
        torch.Tensor: shape ``(num_sequences, 512)``, dtype ``torch.long``,
        each row a full-length bit sequence.
    """
    import ast  # local import: only needed here to parse stringified bit lists

    logger.info("Loading WCNegentropy/BitTransformerLM dataset...")

    # Login to HuggingFace (token optional; public datasets still load).
    hf_token = os.getenv('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
    else:
        # Use the module logger for consistency with the rest of the script.
        logger.warning("HF_TOKEN environment variable not set")

    # Load dataset
    dataset = load_dataset("WCNegentropy/BitTransformerLM")
    train_data = dataset['train']

    logger.info(f"Dataset loaded: {len(train_data)} samples")

    # Process dataset - the HF dataset already has bit_sequence field!
    bit_sequences = []
    for sample in train_data:
        if 'bit_sequence' in sample and sample['bit_sequence'] is not None:
            # The bit_sequence might already be a list
            bits = sample['bit_sequence']
            if isinstance(bits, str):
                try:
                    # literal_eval parses the string representation of a list
                    # without eval()'s arbitrary-code-execution risk.
                    bits = ast.literal_eval(bits)
                except (ValueError, SyntaxError):
                    bits = None
            if isinstance(bits, list) and len(bits) > 0:
                bit_sequences.append(bits)
            else:
                # Fallback: convert original_text to bits
                text = sample.get('original_text', '')
                if text:
                    bits = text_to_bits(text)
                    bit_sequences.append(bits)
        else:
            # Fallback: convert text to bits
            text = sample.get('text', '') or sample.get('original_text', '')
            if text:
                bits = text_to_bits(text)
                bit_sequences.append(bits)

    logger.info(f"Processed {len(bit_sequences)} bit sequences")

    # Create training tensors with proper sequence length
    max_len = 512  # BitTransformerLM default max_seq_len
    training_sequences = []

    for bits in bit_sequences:
        # Split long sequences into overlapping chunks (stride = half window).
        for i in range(0, len(bits) - max_len + 1, max_len // 2):
            seq = bits[i:i + max_len]
            if len(seq) == max_len:  # Only use full-length sequences
                training_sequences.append(seq)

    # Convert to tensor
    data_tensor = torch.tensor(training_sequences, dtype=torch.long)
    logger.info(f"Created training tensor: {data_tensor.shape}")

    return data_tensor

def create_breakthrough_model():
    """Instantiate the exact breakthrough BitTransformerLM configuration.

    Builds the model with the hyperparameters from the Fixed RL Adafactor
    breakthrough run (~16M parameters) and logs whether the realized
    parameter count lands in the expected range.
    """
    logger.info("Creating breakthrough BitTransformerLM model...")

    # Exact breakthrough hyperparameters, using actual BitTransformerLM args.
    breakthrough_config = dict(
        d_model=512,            # breakthrough model width
        nhead=16,               # 16 attention heads
        num_layers=8,           # 8 layers for ~16M params
        dim_feedforward=1024,   # 2x d_model
        max_seq_len=512,        # must match data preparation
        reversible=True,        # memory efficiency
        use_checkpoint=True,    # gradient checkpointing
        use_autocast=True,      # mixed precision
        use_act=True,           # Adaptive Computation Time
        act_threshold=0.9,
        lambda_K=0.05,          # safety telemetry weights
        lambda_C=0.05,
        lambda_S=0.05,
    )
    model = BitTransformerLM(**breakthrough_config)

    # Report the realized parameter count against the ~16M target.
    total_params = sum(param.numel() for param in model.parameters())
    in_target_range = 15_000_000 <= total_params <= 17_000_000
    logger.info(f"Model created: {total_params:,} parameters")
    logger.info(f"Target: ~16M parameters - {'โœ“' if in_target_range else 'โœ—'}")

    return model

def main():
    """Run the full breakthrough training pipeline.

    Loads the dataset, builds the model, configures the Fixed RL Adafactor
    optimizer, runs the actual BitTransformerLM ``train_loop``, saves a
    checkpoint, and logs the final metrics (flagging loss < 3.0 as
    breakthrough performance).
    """
    logger.info("๐Ÿš€ STARTING BREAKTHROUGH BITRANSFORMERLM TRAINING!")
    logger.info("Using ACTUAL BitTransformerLM model and train_loop")

    # Load dataset
    data = load_and_prepare_dataset()

    # Create model
    model = create_breakthrough_model()

    # CRITICAL: Use Fixed RL Adafactor (the breakthrough secret!)
    logger.info("Configuring Fixed RL Adafactor optimizer...")
    optimizer, scheduler = configure_adafactor_optimizer(
        model,
        lr=1e-3,  # FIXED learning rate - key to breakthrough!
        weight_decay=0.01,
        total_steps=5000  # Estimated total steps
    )
    logger.info("Fixed RL Adafactor configured with LR=0.001")

    # Training configuration passed straight through to train_loop.
    training_config = {
        'epochs': 20,           # Reasonable number of epochs
        'batch_size': 4,        # Adjust based on memory
        'accum_steps': 4,       # Gradient accumulation
        'amp': True,           # Mixed precision
        'log': True,           # Enable logging
        'compress_prob': 0.0,  # Start with no compression
        'optimizer': optimizer,
        'scheduler': scheduler
    }

    logger.info(f"Training configuration: {training_config}")
    logger.info("Starting training loop...")

    # Use the ACTUAL BitTransformerLM train_loop function!
    metrics = train_loop(
        model=model,
        data=data,
        **training_config
    )

    # Save the trained model.
    checkpoint_dir = Path('/data/BitTransformerLM/checkpoints')
    # parents=True: without it mkdir raises FileNotFoundError when any
    # ancestor directory is missing, losing the finished training run.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    model_path = checkpoint_dir / 'breakthrough_model.pt'
    save_model(model, model_path)
    logger.info(f"Model saved to: {model_path}")

    # Log final metrics (train_loop is assumed to return a list of per-step
    # metric dicts with 'raw_loss'/'raw_acc' keys — TODO confirm).
    if metrics:
        final_metrics = metrics[-1]
        logger.info("๐ŸŽ‰ TRAINING COMPLETED!")
        logger.info(f"Final raw_loss: {final_metrics['raw_loss']:.6f}")
        logger.info(f"Final raw_acc: {final_metrics['raw_acc']:.3f}")

        # Check for breakthrough performance
        if final_metrics['raw_loss'] < 3.0:
            logger.info("๐Ÿš€ BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")

    logger.info("Breakthrough training completed successfully!")

if __name__ == "__main__":
    main()