JiRack_empty / source_jit /fine_tune_jit_with_validation_H4_L2.py
kgrabko's picture
Upload 16 files
c88fe21 verified
# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.
import os
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import math
from torch.cuda.amp import autocast, GradScaler # 👈 Добавлен импорт AMP
# Параметры (пример)
TRAIN_SEQ_LEN = 256
BATCH_SIZE = 12
EPOCHS = 10
LEARNING_RATE = 1e-6 # 👈 СНИЖЕНО ДЛЯ СТАБИЛЬНОСТИ
WEIGHT_DECAY = 0.01
GRAD_CLIP = 0.5
VAL_SPLIT_RATIO = 0.05
BASE_MODEL_PATH = Path("models/JiRack_H4_L2_V50257_D768_MSL8192_FF768x4.script.pt")
DATASET_PATH = Path("datasets/dialogues_text_clean.txt")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Устройство: {device}")
def print_model_devices(model):
sd = model.state_dict()
devs = set()
for k, v in sd.items():
try:
devs.add(v.device)
except Exception:
devs.add(torch.device('cpu'))
print("Devices present in model.state_dict():", devs)
return devs
def safe_load_jit_model(path: Path, map_device: torch.device):
"""
Загружает JIT модель с map_location и пытается привести её к map_device.
Возвращает (model, model_device) — модель и устройство, на котором находятся её параметры/буферы.
"""
if not path.exists():
raise FileNotFoundError(f"JIT model not found: {path}")
# Попытка загрузки с map_location
print(f"Loading JIT model from {path} with map_location={map_device} ...")
model = torch.jit.load(str(path), map_location=str(map_device))
print("Loaded model. Попытка model.to(...) ...")
try:
model = model.to(map_device)
print("model.to(map_device) выполнен.")
except Exception as e:
# У некоторых JIT объектов .to() может не сработать — это нормально, продолжим диагностику
print("Warning: model.to(map_device) вызвал исключение:", e)
# Диагностика устройств, где лежат параметры/буферы
devs = print_model_devices(model)
# Выберем устройство "модели" — если их несколько, отдаём предпочтение CUDA если есть
if len(devs) == 0:
model_device = map_device
elif len(devs) == 1:
model_device = list(devs)[0]
else:
# если есть смешанные устройства — попробуем приоритет cuda, иначе первый в множестве
cuda_devs = [d for d in devs if 'cuda' in str(d)]
model_device = cuda_devs[0] if cuda_devs else list(devs)[0]
print("Внимание: обнаружены несколько устройств внутри state_dict(). Выбран model_device =", model_device)
# Если model_device не равен map_device — уведомим пользователя и попытаемся ещё раз загрузить с конкретным map_location
if str(model_device) != str(map_device):
print(f"Model tensors are on {model_device} but requested map_device is {map_device}.")
print("Попробую заново загрузить модель с map_location=model_device ...")
try:
model = torch.jit.load(str(path), map_location=str(model_device))
try:
model = model.to(model_device)
except Exception:
pass
devs2 = print_model_devices(model)
if len(devs2) == 1 and list(devs2)[0] == model_device:
print("Успешно перезагружено на целевое устройство.")
except Exception as e:
print("Не удалось перезагрузить модель на желаемое устройство:", e)
# продолжаем, но предупредим пользователя
return model, model_device
def get_logits_from_model(model, inputs):
"""
Вызов модели, допускающий возможные варианты возврата.
Мы предполагаем, что inputs уже находится на том же устройстве, что и модель.
"""
try:
out = model(inputs)
# model может вернуть logits или (logits, kv)
if isinstance(out, tuple) or isinstance(out, list):
return out[0]
return out
except RuntimeError as e:
# Если ошибка связана с устройствами, добавим детальный лог
msg = str(e)
if "Expected all tensors to be on the same device" in msg or "but found at least two devices" in msg:
print("RuntimeError: вероятно есть mismatch устройств (cpu/cuda) внутри model. Диагностика state_dict():")
try:
print_model_devices(model)
except Exception:
pass
# Ребросим исключение с более понятным сообщением
raise RuntimeError("Device mismatch while running the JIT model. See printed diagnostics above.") from e
else:
raise
# ----------------- Пример интеграции в train loop -----------------
def train():
model, model_device = safe_load_jit_model(BASE_MODEL_PATH, device)
# Подготовьте датасеты здесь как вы уже делаете (замените на свой TextDataset)
from transformers import GPT2TokenizerFast
# Замените на ваш реальный TextDataset; здесь лишь заглушка
class DummyDataset(torch.utils.data.Dataset):
def __init__(self, n=1000, seq_len=TRAIN_SEQ_LEN, vocab_size=50257):
self.n = n
self.seq_len = seq_len
self.vocab_size = vocab_size
def __len__(self): return self.n
def __getitem__(self, i):
x = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
y = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
return x, y
train_dataset = DummyDataset(n=2000)
val_dataset = DummyDataset(n=200)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
# Создаём optimizer
params = list(model.parameters()) if hasattr(model, 'parameters') else []
if len(params) == 0:
print("Warning: model.parameters() пуст. Убедитесь, что JIT-модель содержит параметры для оптимизации.")
optimizer = optim.AdamW(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) if params else None
criterion = nn.CrossEntropyLoss()
# Инициализация GradScaler для AMP
scaler = GradScaler()
model.train()
for epoch in range(1, EPOCHS + 1):
print(f"Эпоха {epoch}/{EPOCHS}")
epoch_loss = 0.0
pbar = tqdm(train_loader, desc=f"Epoch {epoch} [TRAIN]", leave=False)
batch_count = 0
skipped_batches = 0
for xb, yb in pbar:
# === 1. ПРОВЕРКА ДАННЫХ НА NAN/INF ===
# Проверяем только если тип данных — float (для LongTensor проверка не нужна)
if torch.is_floating_point(xb) and (torch.isnan(xb).any() or torch.isinf(xb).any()):
print(f"\n[E{epoch}] WARNING: NaN or Inf found in input data (xb). Skipping batch.")
skipped_batches += 1
continue
# Приводим батчи к устройству модели (model_device)
xb = xb.to(model_device)
yb = yb.to(model_device)
if optimizer:
optimizer.zero_grad()
# === 2. AMP: Выполняем forward-pass в half-precision ===
with autocast():
logits = get_logits_from_model(model, xb)
# У logits размер [B, seq_len, vocab] — приводим к числу классов
loss = criterion(logits.view(-1, logits.size(-1)), yb.view(-1))
# ========================================================
# === 3. ПРОВЕРКА ЛОССА НА NAN/INF ПЕРЕД BACKWARD ===
# Проверяем лосс, который теперь может быть float16 или float32
if torch.isnan(loss) or torch.isinf(loss):
print(f"\n[E{epoch}] CRITICAL: Loss is NaN or Inf. Skipping backward and update.")
skipped_batches += 1
continue
# AMP: Вычисляем градиенты, масштабируя их
scaler.scale(loss).backward()
if optimizer:
# AMP: Сначала снимаем масштаб
scaler.unscale_(optimizer)
# Обрезка градиентов
torch.nn.utils.clip_grad_norm_(params, GRAD_CLIP)
# AMP: Обновляем веса (scaler сам проверяет, не являются ли градиенты Inf/NaN)
scaler.step(optimizer)
scaler.update()
# Переводим лосс в float32 для записи и отображения
loss_val = loss.item()
epoch_loss += loss_val
batch_count += 1
pbar.set_postfix({"loss": f"{loss_val:.4f}", "ppl": f"{math.exp(min(loss_val, 10)):.2f}"})
# Средняя потеря считается только по не пропущенным батчам
avg_loss = epoch_loss / batch_count if batch_count > 0 else float('nan')
print(f"Средняя потеря за эпоху: {avg_loss:.4f}")
if skipped_batches > 0:
print(f"Внимание: {skipped_batches} батчей было пропущено из-за NaN/Inf в данных или лоссе.")
if __name__ == "__main__":
train()