kgrabko committed on
Commit
4e9d00e
·
verified ·
1 Parent(s): ab8854e

Upload load_ThePile_800Gb_JiRackTernary_70b.py

Browse files
load_ThePile_800Gb_JiRackTernary_70b.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==============================================================================
2
+ # COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
3
+ # PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
4
+ #
5
+ # This software is licensed under the Commercial License Agreement V.1.2.
6
+ # Any use, modification, or distribution of this code requires compliance with
7
+ # the terms found in the LICENSE.md file in the root directory.
8
+ #
9
+ # NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
10
+ # based on the BRE or SWA architectures disclosed herein.
11
+ # Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
12
+ # ==============================================================================
13
+
14
+ # Version 2.3 - Enterprise Signature & 70B Optimization
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from transformers import AutoTokenizer
19
+ from datasets import load_dataset
20
+ from torch.cuda.amp import autocast, GradScaler
21
+ import os
22
+ import time
23
+
24
+ # Импорт вашей архитектуры
25
+ from JiRackTernaryPyTorch_70b import JiRackTernary70B, JiRackTernaryConfig
26
+
27
+ # --- КОНФИГУРАЦИЯ CMS MANHATTAN ---
28
+ CHECKPOINT_DIR = "checkpoints_jirack"
29
+ MODEL_PATH_LATEST = os.path.join(CHECKPOINT_DIR, "jirack_70b_latest.pt")
30
+ SAVE_INTERVAL = 500
31
+ GRAD_ACCUM_STEPS = 16 # Увеличено для стабильности 70B
32
+ BLOCK_SIZE = 4096 # Поддержка RoPE Scaling
33
+ LEARNING_RATE = 1.0e-5 # Понижено для 70B (предотвращение взрыва лосса)
34
+
35
def save_checkpoint(model, optimizer, step, config):
    """Atomically persist the model, optimizer and config at a training step.

    The checkpoint is written to a temporary file and then renamed over the
    target, so a crash mid-write never leaves a truncated "latest" file.

    Args:
        model: the model to save; may be wrapped in ``nn.DataParallel``.
        optimizer: optimizer whose state accompanies the weights.
        step: global training step recorded in the checkpoint.
        config: model configuration object stored alongside the weights.
    """
    # exist_ok=True avoids the check-then-create race of the original
    # os.path.exists() / os.makedirs() pair.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # Unwrap DataParallel so state-dict keys carry no "module." prefix.
    raw_model = model.module if hasattr(model, 'module') else model
    checkpoint = {
        'step': step,
        'model_state_dict': raw_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config,
        'author_verified': raw_model.get_author_info(),
    }

    # Atomic save: os.replace() is an atomic rename on the same filesystem.
    temp_path = MODEL_PATH_LATEST + ".tmp"
    torch.save(checkpoint, temp_path)
    os.replace(temp_path, MODEL_PATH_LATEST)
    print(f"\n[CMS Manhattan] Авторская копия сохранена на шаге {step}.")
53
+
54
def load_latest_checkpoint(model, optimizer):
    """Restore model and optimizer state from the latest checkpoint, if any.

    Args:
        model: target model; may be wrapped in ``nn.DataParallel``.
        optimizer: optimizer to restore alongside the weights.

    Returns:
        The step stored in the checkpoint, or 0 when no checkpoint exists.
    """
    # Guard clause: nothing to resume from.
    if not os.path.exists(MODEL_PATH_LATEST):
        return 0

    print("--- [RESUME] Поиск подписи автора... ---")
    # weights_only=False is required: the checkpoint embeds a full config
    # object and optimizer state, not just tensors.  torch>=2.6 flipped the
    # default to True, which would make this load fail.
    # NOTE(review): full unpickling executes arbitrary code — only load
    # checkpoints produced by this pipeline.
    checkpoint = torch.load(MODEL_PATH_LATEST, map_location='cpu',
                            weights_only=False)

    # Unwrap DataParallel so state-dict keys line up.
    target = model.module if hasattr(model, 'module') else model
    target.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    print(f"--- [OK] Автор: {checkpoint.get('author_verified', 'Unknown')} ---")
    return checkpoint['step']
66
+
67
def train():
    """Stream The Pile and train JiRackTernary70B with mixed precision.

    Resumes from the latest checkpoint when one exists, accumulates
    gradients over GRAD_ACCUM_STEPS micro-batches, clips gradients, and
    periodically writes atomic checkpoints via save_checkpoint().
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scaler = GradScaler()  # mandatory loss scaling for the ternary weights

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    # Initialize the 70B model with its full parameter set.
    config = JiRackTernaryConfig()
    model = JiRackTernary70B(config)

    # Enable activation (gradient) checkpointing to reduce memory usage.
    model.gradient_checkpointing_enable()

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.1)

    # Returns 0 when no checkpoint exists, so training starts fresh.
    start_step = load_latest_checkpoint(model, optimizer)

    print("Подключение к The Pile (Streaming)...")
    raw_dataset = load_dataset("monology/pile-uncopyrighted", streaming=True, split="train")
    # NOTE(review): skip(start_step) assumes exactly one dataset example per
    # step — true here (one example == one micro-batch); revisit if batching
    # is ever introduced.
    dataset = raw_dataset.skip(start_step) if start_step > 0 else raw_dataset

    model.train()
    current_step = start_step

    for example in dataset:
        # One example per micro-batch, truncated to BLOCK_SIZE tokens.
        tokens = tokenizer(
            example["text"],
            truncation=True,
            max_length=BLOCK_SIZE,
            return_tensors="pt"
        )
        input_ids = tokens["input_ids"].to(device)

        # Mixed-precision forward pass; labels == inputs for causal LM loss.
        with autocast():
            outputs = model(input_ids, labels=input_ids)
            # .mean() collapses per-GPU losses under DataParallel; dividing
            # by GRAD_ACCUM_STEPS makes accumulated grads match one big batch.
            loss = outputs.loss.mean() / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        # Step the optimizer only every GRAD_ACCUM_STEPS micro-batches.
        if (current_step + 1) % GRAD_ACCUM_STEPS == 0:
            # Unscale before clipping so the 1.0 threshold applies to the
            # true (unscaled) gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        if current_step % 5 == 0:
            # Multiply back by GRAD_ACCUM_STEPS to report the un-divided loss.
            print(f"CMS 70B | Step {current_step} | Loss: {loss.item()*GRAD_ACCUM_STEPS:.4f} | Author: Grabko", end='\r')

        # current_step > start_step prevents immediately re-saving the
        # checkpoint we just resumed from.
        if current_step % SAVE_INTERVAL == 0 and current_step > start_step:
            save_checkpoint(model, optimizer, current_step, config)

        current_step += 1
125
+
126
if __name__ == "__main__":
    # Let the CUDA caching allocator grow segments, reducing fragmentation
    # OOMs on long 70B runs.  NOTE(review): this env var must be set before
    # CUDA is initialized; it works here only because torch initializes CUDA
    # lazily — confirm nothing touches CUDA before train().
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    try:
        train()
    except KeyboardInterrupt:
        # Graceful stop on Ctrl+C; the latest checkpoint remains on disk.
        print("\n[!] Остановка. Система CMS Manhattan в режиме ожидания.")