kgrabko commited on
Commit
f13eb01
·
verified ·
1 Parent(s): 4e9d00e

Upload train_70b_heavy_mixed_val_data.py

Browse files
Files changed (1) hide show
  1. train_70b_heavy_mixed_val_data.py +151 -0
train_70b_heavy_mixed_val_data.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==============================================================================
2
+ # COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
3
+ # PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
4
+ #
5
+ # This software is licensed under the Commercial License Agreement V.1.2.
6
+ # Any use, modification, or distribution of this code requires compliance with
7
+ # the terms found in the LICENSE.md file in the root directory.
8
+ #
9
+ # NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
10
+ # based on the BRE or SWA architectures disclosed herein.
11
+ # Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
12
+ # ==============================================================================
13
+
14
+ ##
15
+ ## Mix The Pile with custom cultural data for fine-tuning, giving priority to client data.
16
+ ##
17
+
18
+
19
+ import torch
20
+ import os
21
+ import random
22
+ import json
23
+ from torch.utils.data import DataLoader, IterableDataset
24
+ from transformers import AutoModelForCausalLM, AutoTokenizer
25
+ from datasets import load_dataset # Загрузка The Pile
26
+ import accelerate
27
+
28
# --- CONFIGURATION ---
MODEL_ID = "./models/ternary_70b_init"
GENERAL_DATA_LINK = "monology/pile-uncopyrighted"  # Link to The Pile
CLIENT_DATA_FILE = "cultural_finetune.jsonl"  # Your evolutionary index
OUTPUT_DIR = "./models/checkpoints_70b"

MIX_RATIO = 0.4  # 40% cultural code, 60% The Pile
LEARNING_RATE = 5e-6  # Lower, for 70B stability
SAVE_STEPS = 50  # Save more often for Sidecar on heavy models
37
+
38
# --- BUILT-IN MIXER (keeps the script self-contained) ---
class CMSDataMixer(IterableDataset):
    """Endless stream of tokenized samples mixing client data with The Pile.

    On every draw a client ("cultural") sample is chosen with probability
    ``mix_ratio``; otherwise the next record of the streamed general dataset
    is used. The general stream is restarted whenever it is exhausted.

    Args:
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors when called with ``return_tensors="pt"``.
        client_file: Path to a JSONL file whose records carry ``question``
            and ``answer`` keys.
        pile_link: Dataset identifier handed to ``load_dataset`` in
            streaming mode.
        mix_ratio: Probability of drawing a client sample.

    Yields:
        dict with 1-D ``input_ids``, ``attention_mask`` and ``labels``
        tensors; padded positions of ``labels`` are set to ``IGNORE_INDEX``.
    """

    # Fixed sequence length every sample is truncated / padded to.
    MAX_LENGTH = 512
    # Label value used for padded positions so they can be excluded from the loss.
    IGNORE_INDEX = -100

    def __init__(self, tokenizer, client_file, pile_link, mix_ratio=0.4):
        super().__init__()
        self.tokenizer = tokenizer
        self.mix_ratio = mix_ratio

        # Stream The Pile instead of downloading it.
        print(f">>> [MIXER] Streaming general knowledge from: {pile_link}")
        self.pile_stream = load_dataset(pile_link, split="train", streaming=True)

        # Client data is small enough to keep fully in memory.
        self.cultural_data = self._load_client_samples(client_file)

    @staticmethod
    def _load_client_samples(client_file):
        """Read the client JSONL file; return ``[]`` when the file is missing."""
        if not os.path.exists(client_file):
            print(f"⚠️ ERROR: {client_file} not found!")
            return []

        with open(client_file, "r", encoding="utf-8") as f:
            # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
            samples = [json.loads(line) for line in f if line.strip()]
        print(f">>> [MIXER] Loaded {len(samples)} client samples.")
        return samples

    def __iter__(self):
        pile_iterator = iter(self.pile_stream)
        # Whether the current pass over the general stream produced anything.
        pile_progressed = False

        while True:
            if random.random() < self.mix_ratio and self.cultural_data:
                sample = random.choice(self.cultural_data)
                text = f"Question: {sample['question']}\nAnswer: {sample['answer']}"
            else:
                try:
                    sample = next(pile_iterator)
                except StopIteration:
                    client_usable = bool(self.cultural_data) and self.mix_ratio > 0
                    if not pile_progressed and not client_usable:
                        # The general stream is empty and no client sample can
                        # ever be drawn: stop instead of spinning forever.
                        return
                    pile_iterator = iter(self.pile_stream)
                    pile_progressed = False
                    continue
                pile_progressed = True
                text = sample['text']

            encoded = self.tokenizer(
                text,
                truncation=True,
                max_length=self.MAX_LENGTH,
                padding="max_length",
                return_attention_mask=True,
                return_tensors="pt",
            )
            input_ids = encoded["input_ids"].squeeze(0)
            attention_mask = encoded["attention_mask"].squeeze(0)

            # Clone so labels do not alias input_ids, then hide padding from
            # the loss. The attention mask is used rather than comparing with
            # the pad id because the pad token may be the same as EOS, and
            # genuine EOS tokens must stay in the labels.
            labels = input_ids.clone()
            labels[attention_mask == 0] = self.IGNORE_INDEX

            yield {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
            }
79
+
80
# --- MAIN 70B TRAINING LOOP ---
def train_heavy():
    """Fine-tune the 70B model on the mixed client / Pile stream.

    Loads the tokenizer and model from ``MODEL_ID``, builds a ``CMSDataMixer``
    over ``CLIENT_DATA_FILE`` and ``GENERAL_DATA_LINK``, and trains with AdamW
    under Accelerate using 4-step gradient accumulation. A checkpoint is
    written to ``OUTPUT_DIR`` every ``SAVE_STEPS`` steps. The loop runs for as
    long as the data loader keeps producing batches.

    Raises:
        RuntimeError: Any runtime error other than an out-of-memory
            condition is re-raised unchanged.
    """
    # Set up the accelerator to distribute the 70B model across the
    # Tesla M10 cluster.
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=4)

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # 1. Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        # The mixer pads every sample to a fixed length, so a pad token is required.
        tokenizer.pad_token = tokenizer.eos_token

    # 2. Load the 70B model
    # device_map="auto" is critical here for spreading the layers across devices.
    print(">>> Loading 70B model layers across GPUs using Accelerate...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    # 3. Initialise the mixer
    dataset = CMSDataMixer(tokenizer, CLIENT_DATA_FILE, GENERAL_DATA_LINK, mix_ratio=MIX_RATIO)
    loader = DataLoader(dataset, batch_size=1, pin_memory=True)

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

    # Hand everything over to accelerate
    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

    print(">>> CMS Heavy Engine (70B) Started.")
    print(f">>> Mixed Strategy: {int(MIX_RATIO*100)}% Client Focus / {int((1-MIX_RATIO)*100)}% Pile.")

    model.train()
    for step, batch in enumerate(loader):
        try:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss

                if torch.isnan(loss):
                    # Do not back-propagate a NaN; drop this micro-batch.
                    print(f"⚠️ NaN loss at step {step}. Skipping...")
                    continue

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            if step % 10 == 0 and accelerator.is_main_process:
                print(f"📊 Step {step} | Loss: {loss.item():.4f}")

            # Checkpoint for Sidecar
            # NOTE(review): save_state is only invoked on the main process;
            # confirm this is intended if the script is ever launched with
            # more than one process.
            if step > 0 and step % SAVE_STEPS == 0 and accelerator.is_main_process:
                save_path = os.path.join(OUTPUT_DIR, f"checkpoint_step_{step}")
                print(f">>> Exporting 70B state: {save_path}")
                accelerator.save_state(save_path)
                torch.cuda.empty_cache()

        except RuntimeError as e:
            if "out of memory" not in str(e):
                # Bare raise keeps the original traceback intact.
                raise
            print(f"❌ OOM on Step {step}. Clearing cache...")
            # Gradients from the interrupted accumulation window may be
            # incomplete; drop them so they neither hold memory nor leak
            # into the next optimizer update. The module is cleared directly
            # because the prepared optimizer may skip zero_grad between
            # accumulation boundaries.
            model.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
149
+
150
# Script entry point: start training only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train_heavy()