# JiRackTernary_3b / train_3b_heavy_mixed_val_data.py
# (Hugging Face Hub page residue, preserved as comments: "kgrabko's picture",
#  "Upload 12 files", commit "bc0c3f5 verified" — these lines were not valid Python.)
# ==============================================================================
# COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
#
# This software is licensed under the Commercial License Agreement V.1.2.
# Any use, modification, or distribution of this code requires compliance with
# the terms found in the LICENSE.md file in the root directory.
#
# NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
# based on the BRE or SWA architectures disclosed herein.
# Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
# ==============================================================================
##
## Mix The Pile with custom cultural client data for fine-tuning,
## giving priority weight to the client's data.
##
import torch
import random
import json
from torch.utils.data import IterableDataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, AdamW
from datasets import load_dataset # Это библиотека для загрузки The Pile
# ==========================================
# CONFIGURATION
# ==========================================
# Path to the locally initialized ternary 3B model checkpoint.
MODEL_ID = "./models/ternary_3b_init"
# Hugging Face Hub dataset id for The Pile (uncopyrighted subset).
GENERAL_DATA_LINK = "monology/pile-uncopyrighted"
# Local JSONL file with the client's cultural fine-tuning data.
CLIENT_DATA_FILE = "cultural_finetune.jsonl"
MIX_RATIO = 0.3 # 30% client cultural data, 70% The Pile
BATCH_SIZE = 4
LEARNING_RATE = 1e-5
MAX_LENGTH = 512
# ==========================================
# DATA MIXING MECHANISM (MIXER)
# ==========================================
class CMSDataMixer(IterableDataset):
    """Infinite iterable dataset mixing client JSONL data with The Pile.

    On each step it yields either a client cultural-data sample (with
    probability MIX_RATIO) or the next streamed Pile sample, tokenized and
    padded to MAX_LENGTH. Yields dicts with ``input_ids``, ``attention_mask``
    and ``labels`` (padding positions masked with -100).
    """

    def __init__(self, tokenizer, client_file, pile_link):
        self.tokenizer = tokenizer
        # Stream The Pile so we do not download the full (~800GB) dataset.
        print(f">>> Подключаюсь к общему датасету: {pile_link}")
        self.pile_stream = load_dataset(pile_link, split="train", streaming=True)
        # Load the client's cultural data from a local JSONL file.
        print(f">>> Загружаю культурный код клиента из: {client_file}")
        self.cultural_data = []
        with open(client_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # FIX: skip blank lines instead of crashing on json.loads("").
                if line:
                    self.cultural_data.append(json.loads(line))
        print(f">>> Миксер готов: {MIX_RATIO*100}% данных будет из {client_file}")

    def __iter__(self):
        pile_iterator = iter(self.pile_stream)
        while True:
            # Roll the dice: which source feeds the model on this step?
            if random.random() < MIX_RATIO:
                # Client cultural data, formatted as a Q/A prompt.
                # Assumes each JSONL record has 'question' and 'answer' keys
                # — confirm against the client file's schema.
                sample = random.choice(self.cultural_data)
                text = f"Question: {sample['question']}\nAnswer: {sample['answer']}"
            else:
                # General knowledge from The Pile.
                try:
                    sample = next(pile_iterator)
                    text = sample['text']
                except StopIteration:
                    # If the stream is ever exhausted (unlikely), restart it.
                    pile_iterator = iter(self.pile_stream)
                    continue
            # Tokenize into fixed-length tensors the model understands.
            tokens = self.tokenizer(
                text,
                truncation=True,
                max_length=MAX_LENGTH,
                padding="max_length",
                return_tensors="pt"
            )
            input_ids = tokens["input_ids"].squeeze(0)
            attention_mask = tokens["attention_mask"].squeeze(0)
            # FIX: mask padding positions with -100 so cross-entropy ignores
            # them; the original trained the model to predict pad tokens.
            labels = input_ids.clone()
            labels[attention_mask == 0] = -100
            yield {
                "input_ids": input_ids,
                "attention_mask": attention_mask,  # new key; existing keys unchanged
                "labels": labels,
            }
# ==========================================
# MAIN TRAINING LOOP
# ==========================================
def run_training(max_steps=None):
    """Fine-tune the ternary model on the mixed Pile + client stream.

    Args:
        max_steps: optional limit on the number of training steps. The
            default ``None`` preserves the original behavior of iterating
            the infinite mixer stream indefinitely.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f">>> Использую устройство: {device}")
    # FIX: the original referenced the undefined name MODEL_NAME (NameError
    # at runtime); the configured constant is MODEL_ID.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Model initialization.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16
    ).to(device)
    # Our "smart" mixing data loader.
    train_dataset = CMSDataMixer(tokenizer, CLIENT_DATA_FILE, GENERAL_DATA_LINK)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
    # FIX: transformers.AdamW is deprecated and removed in recent
    # transformers releases; torch.optim.AdamW is the supported equivalent.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    print(">>> Запуск Fine-tuning...")
    model.train()
    for step, batch in enumerate(train_loader):
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step % 50 == 0:
            print(f"Шаг {step} | Текущая ошибка (Loss): {loss.item():.4f}")
        # Save a checkpoint every 500 steps (skipping step 0).
        if step % 500 == 0 and step > 0:
            model.save_pretrained(f"./checkpoint_step_{step}")
            print(f">>> Чекпоинт сохранен на шаге {step}")
        # Optional stop condition; never triggers with the default None.
        if max_steps is not None and step + 1 >= max_steps:
            break
# Script entry point: start fine-tuning only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    run_training()