WCNegentropy committed on
Commit
bd02910
·
verified ·
1 Parent(s): 984b78b

Remove true_1b_training.py - cleanup for OS launch

Browse files
Files changed (1) hide show
  1. true_1b_training.py +0 -485
true_1b_training.py DELETED
@@ -1,485 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- BitTransformerLM TRUE 1.21B Parameter Training
4
- ==============================================
5
-
6
- The REAL DEAL: 1.21B parameters with PROPER FSDP sharding (not duplication!)
7
- Based on our proven 680M success, now scaled to the full billion+ parameters!
8
- """
9
-
10
- import os
11
- import sys
12
- import time
13
- import json
14
- import logging
15
- import argparse
16
- import torch.multiprocessing as mp
17
- from datetime import datetime
18
- from typing import Dict, Any, Optional
19
-
20
- import torch
21
- import torch.nn as nn
22
- import torch.distributed as dist
23
- import torch.nn.functional as F
24
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
25
- from torch.distributed.fsdp import MixedPrecision, BackwardPrefetch, ShardingStrategy
26
- from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
27
- from torch.utils.data import DataLoader, DistributedSampler
28
- from datasets import load_dataset
29
-
30
- from bit_transformer.model import BitTransformerLM
31
- from bit_transformer.bit_io import text_to_bits
32
- from bit_transformer.utils import set_dropout
33
-
34
- # Configure logging
35
- logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
36
- logger = logging.getLogger(__name__)
37
-
38
-
39
class True1BConfig:
    """Frozen hyper-parameter set for the TRUE 1.21B-parameter training run."""

    # --- Model architecture (full 1.21B parameters) ---
    D_MODEL = 2048
    NUM_LAYERS = 24
    NUM_HEADS = 32
    DIM_FEEDFORWARD = 8192
    MAX_SEQ_LEN = 512  # Optimized length from our 680M success

    # --- Training schedule ---
    BATCH_SIZE_PER_GPU = 1  # Conservative
    NUM_GPUS = 4
    GRADIENT_ACCUMULATION_STEPS = 32
    EFFECTIVE_BATCH_SIZE = BATCH_SIZE_PER_GPU * NUM_GPUS * GRADIENT_ACCUMULATION_STEPS  # 128

    LEARNING_RATE = 2e-4
    WEIGHT_DECAY = 0.01
    MAX_STEPS = 1000  # Reasonable for demo
    WARMUP_STEPS = 100

    # --- Memory/speed switches proven on the 680M run ---
    USE_REVERSIBLE = True
    USE_GRADIENT_CHECKPOINTING = True
    USE_MIXED_PRECISION = True
    CHUNK_SIZE = 128  # Chunked attention for memory efficiency
    FULL_ATTN_LOGGING = False  # Memory optimization

    # --- Telemetry loss weights (reduced impact, proven necessary) ---
    LAMBDA_K = 0.1
    LAMBDA_C = 0.1
    LAMBDA_S = 0.1

    @classmethod
    def get_model_config(cls) -> Dict[str, Any]:
        """Assemble the keyword-argument dict consumed by BitTransformerLM(...)."""
        return dict(
            d_model=cls.D_MODEL,
            nhead=cls.NUM_HEADS,
            num_layers=cls.NUM_LAYERS,
            dim_feedforward=cls.DIM_FEEDFORWARD,
            max_seq_len=cls.MAX_SEQ_LEN,
            lambda_K=cls.LAMBDA_K,
            lambda_C=cls.LAMBDA_C,
            lambda_S=cls.LAMBDA_S,
            reversible=cls.USE_REVERSIBLE,
            use_checkpoint=cls.USE_GRADIENT_CHECKPOINTING,
            use_autocast=True,
            chunk_size=cls.CHUNK_SIZE,
            full_attn_logging=cls.FULL_ATTN_LOGGING,
        )
90
-
91
-
92
class OptimizedWikiTextDataset(torch.utils.data.Dataset):
    """WikiText-103 passages encoded as fixed-length bit sequences."""

    def __init__(self, split: str = "train", max_samples: int = 1000, max_length: int = 512):
        self.max_length = max_length

        logger.info(f"Loading WikiText-103 {split} (max {max_samples} samples)...")
        raw = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)

        # Keep only non-trivial passages (>50 chars after stripping),
        # stopping once max_samples have been collected.
        kept = []
        for record in raw:
            if len(record['text'].strip()) > 50:
                kept.append(record['text'])
                if len(kept) == max_samples:
                    break
        self.texts = kept

        logger.info(f"Loaded {len(self.texts)} samples from {split}")

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.texts[idx]

        try:
            bits = text_to_bits(sample)
            # Clamp or zero-pad to exactly max_length bits.
            if len(bits) >= self.max_length:
                bits = bits[:self.max_length]
            else:
                bits = bits + [0] * (self.max_length - len(bits))

            # Next-bit prediction: inputs are bits[:-1], targets bits[1:].
            return {
                'input_ids': torch.tensor(bits[:-1], dtype=torch.long),
                'labels': torch.tensor(bits[1:], dtype=torch.long),
            }

        except Exception:
            # Any encoding failure falls back to an alternating 0/1 pattern.
            fallback = [0, 1] * (self.max_length // 2)
            return {
                'input_ids': torch.tensor(fallback[:-1], dtype=torch.long),
                'labels': torch.tensor(fallback[1:], dtype=torch.long),
            }
139
-
140
-
141
def setup_distributed(rank: int, world_size: int) -> None:
    """Initialize the NCCL process group and pin this worker to its GPU."""
    # Rendezvous address/port and CUDA allocator settings must be in the
    # environment before the process group is initialized.
    os.environ.update({
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '29500',
        'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True',
    })

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
149
-
150
-
151
def cleanup_distributed() -> None:
    """Tear down the process group created by setup_distributed()."""
    dist.destroy_process_group()
154
-
155
-
156
def create_fsdp_model(config: True1BConfig, rank: int) -> FSDP:
    """Build the 1.21B-parameter model and wrap it in a fully sharded FSDP unit."""

    logger.info("🏗️ Creating TRUE 1.21B parameter model with PROPER FSDP sharding...")
    arch = config.get_model_config()

    # Instantiate on CPU; FSDP moves each shard to its device at wrap time.
    base_model = BitTransformerLM(**arch)
    param_count = sum(weight.numel() for weight in base_model.parameters())

    if rank == 0:
        logger.info(f"✅ Base model: {param_count:,} parameters ({param_count/1e9:.2f}B)")

    # fp16 for parameters, gradient reduction, and buffers alike.
    half_precision = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

    # FULL_SHARD partitions parameters/gradients/optimizer state across ranks
    # instead of replicating the whole model on every GPU.
    sharded = FSDP(
        base_model,
        auto_wrap_policy=size_based_auto_wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=half_precision,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=rank,
        limit_all_gathers=True,
        use_orig_params=False,  # Memory optimization
    )

    if rank == 0:
        logger.info("✅ FSDP model created with FULL SHARDING (not duplication)")
        logger.info("🚀 Each GPU handles 1/4 of the 1.21B parameters!")

    return sharded
192
-
193
-
194
def train_step(model: FSDP, batch: Dict[str, torch.Tensor],
               optimizer: torch.optim.Optimizer, scaler: torch.cuda.amp.GradScaler,
               rank: int) -> tuple:
    """Run one forward/backward micro-step and return (loss, telemetry).

    The optimizer step itself happens in the caller once
    GRADIENT_ACCUMULATION_STEPS micro-batches have been accumulated; the
    ``optimizer`` argument is unused here but kept for interface stability.
    Returns the UN-scaled per-micro-batch loss value for logging, plus the
    telemetry dict (empty when the model returns bare logits).
    """

    model.train()

    input_ids = batch['input_ids'].to(rank, non_blocking=True)
    labels = batch['labels'].to(rank, non_blocking=True)

    with torch.cuda.amp.autocast():
        outputs = model(input_ids)

        # The model may return (logits, telemetry) or bare logits.
        if isinstance(outputs, tuple):
            logits, telemetry = outputs
        else:
            logits, telemetry = outputs, {}

        # Binary vocabulary: logits are flattened to (N, 2) for cross-entropy.
        loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1))

    # BUGFIX: gradients are summed across GRADIENT_ACCUMULATION_STEPS
    # micro-batches before the caller steps the optimizer, so each backward
    # must contribute 1/N of the gradient — otherwise the effective learning
    # rate is N× larger than intended.
    scaler.scale(loss / True1BConfig.GRADIENT_ACCUMULATION_STEPS).backward()

    return loss.item(), telemetry
217
-
218
-
219
def save_checkpoint(model: FSDP, optimizer, scheduler, step: int,
                    config: True1BConfig, rank: int) -> str:
    """Gather the full FSDP state dict and write a checkpoint on rank 0.

    Must be called from ALL ranks: FULL_STATE_DICT consolidation is a
    collective operation, so every rank has to enter ``state_dict()`` or
    rank 0 blocks forever waiting for the missing shards.

    Returns the checkpoint path on rank 0 and "" on every other rank.
    """
    # BUGFIX: StateDictType lives in the torch.distributed.fsdp module, not
    # on the FSDP class — the original `FSDP.StateDictType` raised
    # AttributeError before anything was saved.
    from torch.distributed.fsdp import StateDictType

    # Collective gather — executed on every rank (see docstring). The
    # original ran this only under `if rank == 0`, which deadlocks.
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model_state = model.state_dict()

    if rank != 0:
        return ""

    checkpoint_dir = f"/data/checkpoints/true_1b_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        'step': step,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': config.get_model_config(),
        'timestamp': datetime.now().isoformat(),
        'parameters': 1210000000,  # Approximate
    }

    checkpoint_path = f"{checkpoint_dir}/model.pt"
    torch.save(checkpoint, checkpoint_path)
    logger.info(f"💾 1.21B model saved: {checkpoint_path}")
    return checkpoint_path
245
-
246
-
247
def test_inference(model: FSDP, config: True1BConfig, rank: int) -> Dict[str, Any]:
    """Greedy-generate bit continuations for a few prompts on rank 0.

    Returns {'inference_results': [...]} on rank 0 and {} on other ranks.
    Each result records the prompt, bit counts, decoded output (truncated to
    200 chars), and any telemetry the model emitted.
    """
    if rank != 0:
        return {}

    # BUGFIX: bits_to_text was never imported at module level, so decoding
    # always raised NameError and the bare `except:` silently hid it —
    # every run reported "[Generated N bits]" instead of text.
    from bit_transformer.bit_io import bits_to_text

    logger.info("🧪 Testing 1.21B parameter model inference...")

    model.eval()
    set_dropout(model, 0.0)

    inference_results = []

    # Test patterns
    test_patterns = [
        "Hello world",
        "The quick brown fox",
        "In the beginning",
        "Once upon a time",
        "Artificial intelligence"
    ]

    with torch.no_grad():
        for i, text in enumerate(test_patterns):
            try:
                telemetry = {}  # BUGFIX: defined even if generation never runs

                # Convert prompt to bits, leaving room for generation.
                bits = text_to_bits(text)
                if len(bits) > config.MAX_SEQ_LEN - 50:
                    bits = bits[:config.MAX_SEQ_LEN - 50]

                input_bits = torch.tensor(bits, dtype=torch.long).unsqueeze(0).to(rank)

                # Greedy autoregressive continuation (up to 20 extra bits).
                with torch.cuda.amp.autocast():
                    for _ in range(20):
                        outputs = model(input_bits)
                        if isinstance(outputs, tuple):
                            logits, telemetry = outputs
                        else:
                            logits = outputs
                            telemetry = {}

                        # Argmax of the last position's distribution.
                        next_bit_logits = logits[0, -1, :]
                        next_bit = torch.softmax(next_bit_logits, dim=-1).argmax().item()

                        # Append the predicted bit to the sequence.
                        next_tensor = torch.tensor([[next_bit]], dtype=torch.long).to(rank)
                        input_bits = torch.cat([input_bits, next_tensor], dim=1)

                        if input_bits.size(1) >= config.MAX_SEQ_LEN:
                            break

                # Decode bits back to text; fall back to a summary string if
                # the bit stream is not valid text.
                generated_bits = input_bits.squeeze().cpu().tolist()
                try:
                    generated_text = bits_to_text(generated_bits)
                except Exception:  # BUGFIX: narrowed from bare `except:`
                    generated_text = f"[Generated {len(generated_bits)} bits]"

                result = {
                    'input': text,
                    'input_bits': len(bits),
                    'generated_bits': len(generated_bits),
                    'output': generated_text[:200],  # Limit length
                    'telemetry': {k: float(v) if isinstance(v, torch.Tensor) else v
                                  for k, v in telemetry.items()}
                }

                inference_results.append(result)
                logger.info(f"Test {i+1}: '{text}' -> Generated {len(generated_bits)} bits")

            except Exception as e:
                logger.warning(f"Inference test {i+1} failed: {e}")
                inference_results.append({
                    'input': text,
                    'error': str(e)
                })

    logger.info("✅ 1.21B model inference testing complete!")
    return {'inference_results': inference_results}
327
-
328
-
329
def main_worker(rank: int, world_size: int, config: True1BConfig) -> None:
    """Per-process training worker: data, FSDP model, AMP training loop,
    checkpointing, and (on rank 0) inference testing and results export.

    Spawned once per GPU by mp.spawn; `rank` is injected as the first arg.
    """

    setup_distributed(rank, world_size)

    if rank == 0:
        logger.info("🚀 TRUE 1.21B PARAMETER BITTRANSFORMERLM TRAINING!")
        logger.info("=" * 60)
        logger.info("✅ PROPER FSDP SHARDING (not duplication)")
        logger.info("✅ Based on proven 680M success")
        logger.info("✅ All optimizations enabled")

    # Create datasets — each rank sees a disjoint shard via DistributedSampler.
    train_dataset = OptimizedWikiTextDataset("train", max_samples=2000, max_length=config.MAX_SEQ_LEN)

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE_PER_GPU,
        sampler=train_sampler,
        num_workers=0,  # Avoid multiprocessing issues
        pin_memory=True
    )

    # Create FSDP model with PROPER sharding.
    model = create_fsdp_model(config, rank)

    # Setup optimizer and scheduler.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
        betas=(0.9, 0.95)
    )

    # OneCycleLR sized in optimizer steps; pct_start encodes the warmup span.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.LEARNING_RATE,
        total_steps=config.MAX_STEPS,
        pct_start=config.WARMUP_STEPS / config.MAX_STEPS,
    )

    scaler = torch.cuda.amp.GradScaler()

    if rank == 0:
        logger.info("🎯 Starting 1.21B parameter training...")

    # Training loop state. `step` counts optimizer steps, not micro-batches.
    step = 0
    running_loss = 0.0
    start_time = time.time()
    checkpoint_path = ""

    try:
        for epoch in range(10):
            # Reshuffle each rank's shard deterministically per epoch.
            train_sampler.set_epoch(epoch)

            for batch_idx, batch in enumerate(train_loader):
                loss, telemetry = train_step(model, batch, optimizer, scaler, rank)
                running_loss += loss

                # Gradient accumulation: step only every N micro-batches.
                if (batch_idx + 1) % config.GRADIENT_ACCUMULATION_STEPS == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()
                    optimizer.zero_grad()

                    step += 1

                    # Logging — NOTE(review): `telemetry` here is from the
                    # last micro-batch only, and avg_loss divides the summed
                    # micro-batch losses by 10 steps, not by micro-batches.
                    if step % 10 == 0 and rank == 0:
                        avg_loss = running_loss / 10
                        elapsed = time.time() - start_time
                        memory_used = torch.cuda.memory_allocated(rank) / (1024**3)

                        logger.info(
                            f"Step {step:4d} | "
                            f"Loss: {avg_loss:.4f} | "
                            f"K: {telemetry.get('negentropy', 0):.3f} | "
                            f"C: {telemetry.get('lz_complexity', 0):.3f} | "
                            f"S: {telemetry.get('symbiosis', 0):.3f} | "
                            f"LR: {scheduler.get_last_lr()[0]:.2e} | "
                            f"Mem: {memory_used:.1f}GB | "
                            f"Time: {elapsed:.1f}s"
                        )

                        running_loss = 0.0
                        start_time = time.time()

                    # Save checkpoint (called here from all ranks).
                    if step % 100 == 0 and step > 0:
                        checkpoint_path = save_checkpoint(model, optimizer, scheduler, step, config, rank)

                    if step >= config.MAX_STEPS:
                        break

            if step >= config.MAX_STEPS:
                break

        # Final checkpoint — NOTE(review): this call happens on rank 0 only,
        # but FSDP FULL_STATE_DICT gathering is a collective; non-zero ranks
        # never enter save_checkpoint here, which risks a deadlock. Confirm
        # and move the call outside the `if rank == 0` guard.
        if rank == 0:
            checkpoint_path = save_checkpoint(model, optimizer, scheduler, step, config, rank)
            logger.info("🏆 1.21B PARAMETER TRAINING COMPLETED SUCCESSFULLY!")

            # Test inference on the trained model (rank 0 only).
            inference_results = test_inference(model, config, rank)

            # Save results to benchmarks — NOTE(review): `final_loss` is the
            # partial running sum since the last log reset, not a true final
            # average.
            benchmark_data = {
                'timestamp': datetime.now().isoformat(),
                'model_parameters': '1.21B',
                'training_steps': step,
                'final_loss': running_loss,
                'checkpoint_path': checkpoint_path,
                'inference_results': inference_results,
                'config': config.get_model_config(),
            }

            with open('/data/true_1b_results.json', 'w') as f:
                json.dump(benchmark_data, f, indent=2)

            logger.info("📊 Results saved to /data/true_1b_results.json")

    except Exception as e:
        if rank == 0:
            logger.error(f"Training failed: {e}")
        raise
    finally:
        # Always tear down the process group, even on failure.
        cleanup_distributed()
461
-
462
-
463
def main():
    """Entry point: verify GPU availability, then spawn one worker per GPU."""
    config = True1BConfig()
    world_size = 4

    # Require a CUDA box with at least `world_size` GPUs before spawning.
    have_gpus = torch.cuda.is_available() and torch.cuda.device_count() >= world_size
    if not have_gpus:
        print("❌ Need 4 CUDA GPUs for 1.21B training!")
        return

    print("🚀 Launching TRUE 1.21B parameter training with PROPER FSDP sharding!")
    print("🎯 This will work because we've proven the hardware capability!")

    # One process per GPU; mp.spawn prepends the rank to `args`.
    mp.spawn(main_worker, args=(world_size, config), nprocs=world_size, join=True)
482
-
483
-
484
# Script entry point — launches the 4-GPU distributed training run.
if __name__ == "__main__":
    main()