WCNegentropy committed on
Commit
7123635
·
verified ·
1 Parent(s): 8ef2120

Remove massive_scale_training.py - cleanup for OS launch

Browse files
Files changed (1) hide show
  1. massive_scale_training.py +0 -590
massive_scale_training.py DELETED
@@ -1,590 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- BitTransformerLM Massive Scale Training Script
4
- ==============================================
5
-
6
- Scale BitTransformerLM to 1.21 BILLION parameters on extensive real corpus data.
7
- This script configures distributed training across 4x NVIDIA L4 GPUs with FSDP.
8
-
9
- Target Configuration:
10
- - Parameters: 1,208,164,352 (1.21B)
11
- - Architecture: d_model=2048, layers=24, heads=32, ff=8192
12
- - Dataset: WikiText-103 + additional real corpus data
13
- - Hardware: 4x NVIDIA L4 (23GB each), 181GB RAM, 48 CPU cores
14
- """
15
-
16
# Standard library
import argparse
import json
import logging
import math
import os
import sys
import time
import warnings
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

# Third-party
import datasets
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, StateDictType
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader, DistributedSampler

# BitTransformerLM imports
from bit_transformer.model import BitTransformerLM, LoggingTransformerEncoderLayer
from bit_transformer.bit_io import text_to_bits, bits_to_text
from bit_transformer.utils import set_dropout
from bit_transformer.torch_utils import cpu_autocast
45
-
46
# Configure logging to both a persistent log file and stdout so progress
# survives terminal disconnects during multi-day training runs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        # NOTE(review): assumes /data exists and is writable on the host —
        # FileHandler raises at import time otherwise; confirm on target box.
        logging.FileHandler('/data/massive_scale_training.log'),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Suppress warnings for cleaner output (torch/datasets emit many UserWarnings
# under FSDP + streaming).
warnings.filterwarnings('ignore', category=UserWarning)
59
-
60
-
61
class MassiveScaleConfig:
    """Hyperparameters for large-scale BitTransformerLM training on 4x L4 GPUs.

    NOTE(review): the module docstring advertises a 1.21B-parameter model
    (d_model=2048, ff=8192) but the constants below describe the 680M
    GPU-optimized variant — confirm which configuration is intended.
    """

    # Model Architecture (680M parameters - GPU-optimized)
    D_MODEL = 1536
    NUM_LAYERS = 24
    NUM_HEADS = 24            # 1536 / 24 = 64-dim attention heads
    DIM_FEEDFORWARD = 6144    # 4x d_model
    MAX_SEQ_LEN = 2048        # bits per training sample (before input/target shift)

    # Training Configuration
    BATCH_SIZE_PER_GPU = 4  # Increased for 680M parameter model
    GRADIENT_ACCUMULATION_STEPS = 32
    # Hard-codes 4 GPUs: 4 * 4 * 32 = 512 sequences per effective optimizer step.
    EFFECTIVE_BATCH_SIZE = BATCH_SIZE_PER_GPU * 4 * GRADIENT_ACCUMULATION_STEPS  # 512

    LEARNING_RATE = 6e-5  # Scaled for large model
    WEIGHT_DECAY = 0.1
    MAX_STEPS = 50000
    WARMUP_STEPS = 2000

    # Safety & Telemetry: loss weights (lambda_*) and gate thresholds for the
    # K (negentropy), C (LZ complexity), S (symbiosis) telemetry signals
    # checked in the training loop.
    LAMBDA_K = 1.0
    LAMBDA_C = 1.0
    LAMBDA_S = 1.0
    NEGENTROPY_THRESHOLD = 0.15
    LZ_COMPLEXITY_THRESHOLD = 0.25
    SYMBIOSIS_THRESHOLD = 0.4

    # Optimization Features
    USE_REVERSIBLE = True
    USE_GRADIENT_CHECKPOINTING = True
    USE_MIXED_PRECISION = True
    USE_SAFETY_GATES = True

    # Dataset Configuration
    DATASET_NAME = "wikitext"
    DATASET_CONFIG = "wikitext-103-raw-v1"
    MAX_SAMPLES = None  # Use full dataset
    STREAMING = True

    # Logging & Checkpointing intervals (in training steps)
    LOG_INTERVAL = 50
    EVAL_INTERVAL = 1000
    CHECKPOINT_INTERVAL = 2000

    @classmethod
    def get_model_config(cls) -> Dict[str, Any]:
        """Return the keyword arguments used to construct BitTransformerLM."""
        return {
            "d_model": cls.D_MODEL,
            "nhead": cls.NUM_HEADS,
            "num_layers": cls.NUM_LAYERS,
            "dim_feedforward": cls.DIM_FEEDFORWARD,
            "max_seq_len": cls.MAX_SEQ_LEN,
            "lambda_K": cls.LAMBDA_K,
            "lambda_C": cls.LAMBDA_C,
            "lambda_S": cls.LAMBDA_S,
            "reversible": cls.USE_REVERSIBLE,
            "use_checkpoint": cls.USE_GRADIENT_CHECKPOINTING,
            "use_autocast": False,  # Will use FSDP mixed precision instead
            "chunk_size": None,  # Full attention for now
            "full_attn_logging": False,  # Memory optimization
        }
124
-
125
-
126
class WikiTextDataset(torch.utils.data.Dataset):
    """WikiText dataset preprocessed for bit-level training.

    Each sample is a text converted to a bit sequence via ``text_to_bits``,
    truncated/zero-padded to ``max_length``, then split into teacher-forced
    (input, target) pairs of length ``max_length - 1``.
    """

    def __init__(self, split: str = "train", max_samples: Optional[int] = None,
                 max_length: int = 2048, streaming: bool = True):
        # Total bits per sample BEFORE the next-bit input/target shift.
        self.max_length = max_length
        self.streaming = streaming

        logger.info(f"Loading WikiText-103 {split} split...")
        if streaming:
            self.dataset = load_dataset(
                MassiveScaleConfig.DATASET_NAME,
                MassiveScaleConfig.DATASET_CONFIG,
                split=split,
                streaming=True
            )
            if max_samples:
                self.dataset = self.dataset.take(max_samples)
        else:
            self.dataset = load_dataset(
                MassiveScaleConfig.DATASET_NAME,
                MassiveScaleConfig.DATASET_CONFIG,
                split=split
            )
            if max_samples:
                self.dataset = self.dataset.select(range(min(max_samples, len(self.dataset))))

        # Convert to list if not streaming for O(1) indexing; very short
        # texts (<= 50 chars after strip) are dropped as noise.
        if not streaming:
            self.texts = [item['text'] for item in self.dataset if len(item['text'].strip()) > 50]
            logger.info(f"Loaded {len(self.texts)} text samples from {split}")
        else:
            self.texts = None
            logger.info(f"Streaming dataset configured for {split}")

    def __len__(self) -> int:
        if self.texts is not None:
            return len(self.texts)
        else:
            # Rough estimate for streaming — DataLoader/DistributedSampler
            # only need a plausible size. NOTE(review): the "train" substring
            # test against str(self.dataset) is fragile; confirm it matches
            # the datasets library's repr for IterableDataset.
            return 100000 if "train" in str(self.dataset) else 1000

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        if self.texts is not None:
            text = self.texts[idx]
        else:
            # For streaming, we iterate from the start of the stream on every
            # call — O(idx) per sample. NOTE(review): this is a severe
            # bottleneck under random access from a sampler; an
            # IterableDataset would be the conventional fix.
            for i, item in enumerate(self.dataset):
                if i == idx:
                    text = item['text']
                    break
            else:
                # Fallback when the stream is exhausted before reaching idx.
                text = "The quick brown fox jumps over the lazy dog."

        # Convert text to bits
        try:
            bits = text_to_bits(text)

            # Truncate or pad to max_length
            if len(bits) > self.max_length:
                bits = bits[:self.max_length]
            elif len(bits) < self.max_length:
                # Pad with zeros
                bits = bits + [0] * (self.max_length - len(bits))

            # Teacher forcing: inputs are bits[:-1], targets are the same
            # sequence shifted left by one bit.
            input_bits = torch.tensor(bits[:-1], dtype=torch.long)  # Input sequence
            target_bits = torch.tensor(bits[1:], dtype=torch.long)  # Shifted targets

            return {
                'input_ids': input_bits,
                'labels': target_bits,
                'attention_mask': torch.ones_like(input_bits)
            }

        except Exception as e:
            logger.warning(f"Error processing text at index {idx}: {e}")
            # Fallback to a trivial alternating bit pattern so a single bad
            # sample never crashes a long training run.
            fallback_bits = [0, 1] * (self.max_length // 2)
            if len(fallback_bits) < self.max_length:
                fallback_bits.extend([0] * (self.max_length - len(fallback_bits)))

            input_bits = torch.tensor(fallback_bits[:-1], dtype=torch.long)
            target_bits = torch.tensor(fallback_bits[1:], dtype=torch.long)

            return {
                'input_ids': input_bits,
                'labels': target_bits,
                'attention_mask': torch.ones_like(input_bits)
            }
217
-
218
-
219
def setup_distributed(rank: int, world_size: int, port: str = "29500") -> None:
    """Join the NCCL process group and bind this worker to its own GPU."""
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': port})
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
225
-
226
-
227
def cleanup_distributed() -> None:
    """Tear down the process group created by setup_distributed()."""
    dist.destroy_process_group()
230
-
231
-
232
def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
235
-
236
-
237
def create_fsdp_model(model_config: Dict[str, Any], rank: int) -> FSDP:
    """Build a BitTransformerLM and wrap it in FSDP with fp16 mixed precision.

    Args:
        model_config: Keyword arguments for the BitTransformerLM constructor
            (as produced by MassiveScaleConfig.get_model_config()).
        rank: Local GPU index for this worker; the model is moved there
            before wrapping.

    Returns:
        The FSDP-wrapped model ready for distributed training.
    """

    # Create base model on this worker's GPU before sharding.
    model = BitTransformerLM(**model_config)
    model = model.to(rank)

    # Keep parameters, gradient reduction, and buffers all in fp16.
    mixed_precision_policy = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

    # NOTE(review): the policy is passed as the bare function, relying on
    # size_based_auto_wrap_policy's default min_num_params; it is normally
    # bound with functools.partial(..., min_num_params=...) — confirm the
    # installed torch version accepts the unbound form.
    auto_wrap_policy = size_based_auto_wrap_policy

    # Wrap with FSDP
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # overlap comms with backward
        device_id=rank,
        limit_all_gathers=True,  # throttle all-gathers to cap peak memory
    )

    return model
265
-
266
-
267
def log_training_stats(step: int, loss: float, telemetry: Dict[str, float],
                       learning_rate: float, samples_per_sec: float,
                       memory_allocated: float, rank: int) -> None:
    """Emit a one-line training progress summary (rank 0 only)."""
    if rank != 0:
        return
    fields = [
        f"Step {step:6d}",
        f"Loss: {loss:.4f}",
        f"K: {telemetry.get('negentropy', 0):.3f}",
        f"C: {telemetry.get('lz_complexity', 0):.3f}",
        f"S: {telemetry.get('symbiosis', 0):.3f}",
        f"LR: {learning_rate:.2e}",
        f"Speed: {samples_per_sec:.1f} samples/s",
        f"Memory: {memory_allocated:.1f}GB",
    ]
    logger.info(" | ".join(fields))
282
-
283
-
284
def save_checkpoint(model: FSDP, optimizer, scheduler, step: int, loss: float,
                    config: MassiveScaleConfig, rank: int) -> None:
    """Save a full (unsharded) model checkpoint to /data/checkpoints.

    BUG FIX: gathering a FULL_STATE_DICT is a *collective* operation — every
    rank must enter the context and call state_dict(), otherwise rank 0
    blocks waiting for shards that never arrive. The gather now runs on all
    ranks; only rank 0 writes the file. Also uses the StateDictType enum
    imported from torch.distributed.fsdp rather than the unreliable
    FSDP.StateDictType attribute.
    """
    # Collective: all ranks participate in assembling the full state dict.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model_state = model.state_dict()

    if rank == 0:
        # Timestamped directory per call — each checkpoint lands in its own dir.
        checkpoint_dir = f"/data/checkpoints/massive_scale_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        os.makedirs(checkpoint_dir, exist_ok=True)

        checkpoint = {
            'step': step,
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss,
            'config': config.get_model_config(),
            'timestamp': datetime.now().isoformat(),
            # NOTE(review): under FSDP this counts only the local shard's
            # parameters — the reported total may understate the real model.
            'parameters': count_parameters(model),
        }

        checkpoint_path = f"{checkpoint_dir}/checkpoint_step_{step:06d}.pt"
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Checkpoint saved: {checkpoint_path}")
309
-
310
-
311
def train_one_epoch(model: FSDP, train_loader: DataLoader, optimizer, scheduler,
                    config: MassiveScaleConfig, epoch: int, rank: int, world_size: int) -> Tuple[float, Dict[str, float]]:
    """Train for one epoch with gradient accumulation and telemetry safety gates.

    Args:
        model: FSDP-wrapped BitTransformerLM; forward returns (logits, telemetry).
        train_loader: Distributed loader yielding input_ids/labels/attention_mask.
        optimizer: Optimizer, stepped once per GRADIENT_ACCUMULATION_STEPS batches.
        scheduler: LR scheduler, stepped together with the optimizer.
        config: Training hyperparameters.
        epoch: Current epoch index (for caller-side bookkeeping).
        rank: This worker's GPU index.
        world_size: Total worker count (used for throughput logging).

    Returns:
        Tuple of (average unscaled loss over counted steps, last telemetry dict).
    """
    model.train()
    set_dropout(model, 0.1)

    total_loss = 0
    step = 0
    start_time = time.time()
    telemetry: Dict[str, float] = {}  # avoid NameError if the loader is empty

    # BUG FIX: the original called optimizer.zero_grad() on every micro-batch,
    # wiping the gradients being accumulated and reducing the effective batch
    # to a single micro-batch. Gradients are now cleared once up front and
    # then only after each optimizer step.
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(train_loader):
        if step >= config.MAX_STEPS:
            break

        # Move batch to device
        input_ids = batch['input_ids'].to(rank)
        labels = batch['labels'].to(rank)
        attention_mask = batch['attention_mask'].to(rank)

        with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
            logits, telemetry = model(input_ids)

            # Next-bit prediction loss over the 2-way bit vocabulary.
            loss = F.cross_entropy(
                logits.view(-1, 2),
                labels.view(-1),
                reduction='mean'
            )

            # Telemetry safety gates: penalize outputs whose K/C/S signals
            # fall below the configured thresholds.
            if config.USE_SAFETY_GATES:
                negentropy = telemetry.get('negentropy', 0)
                lz_complexity = telemetry.get('lz_complexity', 0)
                symbiosis = telemetry.get('symbiosis', 0)

                if (negentropy < config.NEGENTROPY_THRESHOLD or
                        lz_complexity < config.LZ_COMPLEXITY_THRESHOLD or
                        symbiosis < config.SYMBIOSIS_THRESHOLD):

                    safety_penalty = 10.0  # Strong penalty for unsafe outputs
                    loss = loss + safety_penalty

                    if rank == 0:
                        logger.warning(f"Safety gate triggered at step {step}!")

        # Scale loss for gradient accumulation
        loss = loss / config.GRADIENT_ACCUMULATION_STEPS

        # Backward pass — gradients accumulate across micro-batches.
        loss.backward()

        # Optimizer step once every GRADIENT_ACCUMULATION_STEPS micro-batches.
        if (batch_idx + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # Logging (per optimizer step)
            if step % config.LOG_INTERVAL == 0:
                samples_per_sec = (config.BATCH_SIZE_PER_GPU * world_size *
                                   config.LOG_INTERVAL) / (time.time() - start_time + 1e-7)
                memory_allocated = torch.cuda.memory_allocated(rank) / (1024**3)

                log_training_stats(
                    step, loss.item() * config.GRADIENT_ACCUMULATION_STEPS,
                    telemetry, scheduler.get_last_lr()[0], samples_per_sec,
                    memory_allocated, rank
                )

                start_time = time.time()

            # Checkpointing (per optimizer step; skip step 0)
            if step % config.CHECKPOINT_INTERVAL == 0 and step > 0:
                save_checkpoint(
                    model, optimizer, scheduler, step,
                    loss.item() * config.GRADIENT_ACCUMULATION_STEPS,
                    config, rank
                )

            step += 1
            total_loss += loss.item() * config.GRADIENT_ACCUMULATION_STEPS

    avg_loss = total_loss / max(step, 1)
    return avg_loss, telemetry
403
-
404
-
405
def validate_model(model: FSDP, val_loader: DataLoader, config: MassiveScaleConfig,
                   rank: int) -> Tuple[float, Dict[str, float]]:
    """Evaluate the model on the validation loader, capped at 1000 samples.

    Returns the sample-weighted average loss and the per-sample-averaged
    telemetry dictionary.
    """
    model.eval()
    set_dropout(model, 0.0)

    loss_sum = 0
    seen = 0
    telemetry_sums: Dict[str, float] = {}

    with torch.no_grad():
        for batch in val_loader:
            # Keep validation cheap: stop after ~1000 samples.
            if seen >= 1000:
                break

            inputs = batch['input_ids'].to(rank)
            targets = batch['labels'].to(rank)

            with torch.cuda.amp.autocast(enabled=config.USE_MIXED_PRECISION):
                logits, telemetry = model(inputs)
                batch_loss = F.cross_entropy(
                    logits.view(-1, 2),
                    targets.view(-1),
                    reduction='mean'
                )

            batch_size = inputs.size(0)
            loss_sum += batch_loss.item() * batch_size
            seen += batch_size

            # Accumulate telemetry sums keyed by signal name.
            for key, value in telemetry.items():
                telemetry_sums[key] = telemetry_sums.get(key, 0) + value if key in telemetry_sums else value

    denom = max(seen, 1)
    averaged_telemetry = {key: total / denom for key, total in telemetry_sums.items()}
    return loss_sum / denom, averaged_telemetry
448
-
449
-
450
def main_worker(rank: int, world_size: int, config: MassiveScaleConfig) -> None:
    """Main training worker process — one per GPU, launched via mp.spawn.

    Sets up the process group, builds datasets/loaders and the FSDP model,
    then runs the epoch loop until MAX_STEPS, interruption, or error.
    Always tears down the process group on exit.
    """

    setup_distributed(rank, world_size)

    if rank == 0:
        logger.info("🚀 MASSIVE SCALE BITTRANSFORMERLM TRAINING INITIATED!")
        # NOTE(review): this constructs a full throwaway model just to count
        # parameters — expensive at this scale; confirm it's intentional.
        logger.info(f"Target: {count_parameters(BitTransformerLM(**config.get_model_config())):,} parameters")
        logger.info(f"Hardware: {world_size}x NVIDIA L4 GPUs")
        logger.info(f"Configuration: {config.get_model_config()}")

    # Create datasets (train may stream; validation is fully materialized).
    train_dataset = WikiTextDataset("train", max_samples=config.MAX_SAMPLES,
                                    max_length=config.MAX_SEQ_LEN, streaming=config.STREAMING)
    val_dataset = WikiTextDataset("validation", max_samples=1000,
                                  max_length=config.MAX_SEQ_LEN, streaming=False)

    # Create data loaders; the sampler partitions indices across ranks.
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    # Create FSDP model
    model = create_fsdp_model(config.get_model_config(), rank)

    if rank == 0:
        # NOTE(review): under FSDP this counts only rank 0's shard, so the
        # reported count may understate the full model — verify.
        param_count = count_parameters(model)
        logger.info(f"✅ Model created with {param_count:,} parameters ({param_count/1e9:.2f}B)")

        # Append a run record to the benchmarks markdown log.
        benchmark_update = f"""

### 🔥 LIVE RUN: 1.21B Parameter Training
**Status:** ACTIVE
**Started:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Parameters:** {param_count:,} ({param_count/1e9:.2f}B)
**Architecture:** d_model={config.D_MODEL}, layers={config.NUM_LAYERS}, heads={config.NUM_HEADS}
**Effective Batch Size:** {config.EFFECTIVE_BATCH_SIZE}
**Dataset:** WikiText-103 (streaming)
**Hardware:** 4x NVIDIA L4 GPUs with FSDP

"""
        with open('/data/Benchmarks.md', 'a') as f:
            f.write(benchmark_update)

    # AdamW; beta2=0.95 is the common choice for billion-scale LM training.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        betas=(0.9, 0.95),
    )

    # One-cycle cosine schedule over the whole step budget with warmup.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.LEARNING_RATE,
        total_steps=config.MAX_STEPS,
        pct_start=config.WARMUP_STEPS / config.MAX_STEPS,
        anneal_strategy='cos',
    )

    if rank == 0:
        logger.info("🎯 Starting training loop...")

    # Training loop
    try:
        for epoch in range(100):  # Large number, will stop at MAX_STEPS
            train_sampler.set_epoch(epoch)  # reshuffle the shard each epoch

            train_loss, train_telemetry = train_one_epoch(
                model, train_loader, optimizer, scheduler,
                config, epoch, rank, world_size
            )

            if rank == 0:
                logger.info(f"📈 Epoch {epoch} completed - Average Loss: {train_loss:.4f}")

                # Validation — rank 0 only. NOTE(review): if validate_model
                # triggers FSDP collectives, the other ranks will not
                # participate; confirm this cannot deadlock.
                val_loss, val_telemetry = validate_model(model, val_loader, config, rank)
                logger.info(f"📊 Validation Loss: {val_loss:.4f}")

    except KeyboardInterrupt:
        if rank == 0:
            logger.info("Training interrupted by user")
    except Exception as e:
        if rank == 0:
            logger.error(f"Training failed with error: {e}")
        raise
    finally:
        # Always release NCCL resources, even on failure.
        cleanup_distributed()
554
-
555
-
556
def main():
    """Parse CLI arguments, validate the GPU environment, and spawn workers."""
    parser = argparse.ArgumentParser(description='BitTransformerLM Massive Scale Training')
    parser.add_argument('--world-size', type=int, default=4, help='Number of GPUs')
    # NOTE(review): --port is parsed but never forwarded to the workers;
    # setup_distributed always uses its own default.
    parser.add_argument('--port', type=str, default='29500', help='Master port')
    cli = parser.parse_args()

    config = MassiveScaleConfig()

    # Guard clauses: refuse to start without enough CUDA devices.
    if not torch.cuda.is_available():
        print("❌ CUDA not available! This script requires GPU training.")
        sys.exit(1)

    available = torch.cuda.device_count()
    if available < cli.world_size:
        print(f"❌ Only {available} GPUs available, but {cli.world_size} requested")
        sys.exit(1)

    for line in (
        f"🚀 Launching massive scale training on {cli.world_size} GPUs...",
        f"📊 Target: 1.21 BILLION parameters",
        f"📚 Dataset: WikiText-103 (full corpus)",
        f"🔥 This is going to be EPIC!",
    ):
        print(line)

    # Launch one worker process per GPU; mp.spawn prepends the rank argument.
    mp.spawn(
        main_worker,
        args=(cli.world_size, config),
        nprocs=cli.world_size,
        join=True
    )


if __name__ == "__main__":
    main()