kgrabko commited on
Commit
4fc2831
·
verified ·
1 Parent(s): 73b01b0

Upload load_JiRack5_ThePile_13b.py

Browse files
Files changed (1) hide show
  1. load_JiRack5_ThePile_13b.py +108 -0
load_JiRack5_ThePile_13b.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==============================================================================
2
+ # COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
3
+ # PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
4
+ # ==============================================================================
5
+ # Version 3.6 - 13B Agile Titan | Distributed Optimization
6
+ # Optimized for: huggyllama/llama-7b & monology/pile-uncopyrighted
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from transformers import AutoTokenizer
11
+ from datasets import load_dataset
12
+ from torch.cuda.amp import autocast, GradScaler
13
+ import os
14
+ import sys
15
+
16
+ # Import the 13B Architecture
17
+ from JiRackPyTorch_GPT5_class_13b import JiRackPyTorch
18
+
19
+ # --- CMS MANHATTAN CONFIGURATION ---
20
+ CHECKPOINT_DIR = "checkpoints_jirack_13b_fixed"
21
+ SAVE_INTERVAL = 1000
22
+ GRAD_ACCUM_STEPS = 16
23
+ BLOCK_SIZE = 2048
24
+ LEARNING_RATE = 3.0e-4
25
+
def train():
    """Stream monology/pile-uncopyrighted and train the 13B JiRack model.

    Open-ended training loop with gradient accumulation, bf16 mixed
    precision, gradient clipping, and periodic checkpointing into
    CHECKPOINT_DIR. On any error, prints the traceback and exits the
    process with status 1.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # FIX: the forward pass below runs under bfloat16 autocast. bf16 has the
    # same exponent range as fp32, so fp16-style loss scaling is unnecessary,
    # and multiplying a bf16 loss by a large scale factor loses precision in
    # its 8-bit mantissa. The scaler is therefore disabled; its API is kept
    # so the scale/unscale/step/update flow is unchanged (all pass-throughs).
    scaler = GradScaler(enabled=False)

    # 1. Tokenizer: the fast (Rust) implementation is required to keep up
    # with a streaed dataset of this size.
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=True)
    if tokenizer.pad_token is None:
        # Llama ships without a pad token; reuse EOS so "max_length" padding works.
        tokenizer.pad_token = tokenizer.eos_token

    # 2. Dataset loading — streaming, so nothing is downloaded up front.
    print("Connecting to monology/pile-uncopyrighted...")
    dataset = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)

    # If training on multiple GPUs, each GPU must see different data.
    if torch.cuda.device_count() > 1:
        # Simple shard logic for DataParallel simulation.
        # In a full DDP setup, use DistributedSampler.
        print(f"Detected {torch.cuda.device_count()} GPUs. Distributing workload...")

    # 3. Model: pass the tokenizer length so the embedding layer matches the vocab.
    model = JiRackPyTorch(vocab_size=len(tokenizer))
    model.gradient_checkpointing_enable()

    is_parallel = torch.cuda.device_count() > 1
    if is_parallel:
        model = nn.DataParallel(model)
    model.to(device)

    # 4. Optimizer. Weight decay 0.1 is critical for 13B to prevent latent
    # space collapse.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.1)

    model.train()
    print("--- [FIXED] Training Started: JiRack 13B ---")

    try:
        for current_step, example in enumerate(dataset):
            # Tokenize with fixed-length padding/truncation to BLOCK_SIZE.
            tokens = tokenizer(
                example["text"],
                truncation=True,
                max_length=BLOCK_SIZE,
                padding="max_length",
                return_tensors="pt"
            )

            input_ids = tokens["input_ids"].to(device)

            # 5. Forward pass under bf16 mixed precision.
            with autocast(dtype=torch.bfloat16):
                # Causal LM: targets = inputs (model presumably shifts
                # internally — TODO confirm against JiRackPyTorch.forward).
                logits, loss, _ = model(input_ids, targets=input_ids)
                # .mean() collapses the per-GPU losses DataParallel gathers;
                # dividing makes accumulated gradients average, not sum.
                loss = loss.mean() / GRAD_ACCUM_STEPS

            # 6. Backward pass (scaler is a no-op pass-through when disabled).
            scaler.scale(loss).backward()

            if (current_step + 1) % GRAD_ACCUM_STEPS == 0:
                scaler.unscale_(optimizer)
                # Gradient clipping tightened to 1.0 for 13B stability.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                # set_to_none=True frees gradient tensors between steps
                # instead of filling them with zeros.
                optimizer.zero_grad(set_to_none=True)

            if current_step % 50 == 0:
                print(f"Step {current_step} | Loss: {loss.item()*GRAD_ACCUM_STEPS:.4f} | "
                      f"Alloc: {torch.cuda.memory_allocated()/1e9:.1f}GB", end='\r')

            if current_step % SAVE_INTERVAL == 0 and current_step > 0:
                save_path = os.path.join(CHECKPOINT_DIR, f"step_{current_step}.pt")
                # FIX: unwrap DataParallel before saving, otherwise every key
                # in the checkpoint is prefixed with "module." and cannot be
                # loaded into a bare JiRackPyTorch instance.
                state = model.module.state_dict() if is_parallel else model.state_dict()
                torch.save(state, save_path)

    except Exception as e:
        # Surface the full traceback before exiting non-zero — the bare
        # message alone made failures undebuggable.
        import traceback
        traceback.print_exc()
        print(f"\n[CRITICAL ERROR] Training interrupted: {e}")
        sys.exit(1)
+
104
+ if __name__ == "__main__":
105
+ # Allocator fix for Tesla M10 to prevent OOM during peak activation
106
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:64"
107
+ if not os.path.exists(CHECKPOINT_DIR): os.makedirs(CHECKPOINT_DIR)
108
+ train()