import math
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# Legacy AMP namespace; newer PyTorch versions expose the same functionality as torch.amp.
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm

# Training hyperparameters.
TRAIN_SEQ_LEN = 256
BATCH_SIZE = 12
EPOCHS = 10
LEARNING_RATE = 1e-6
WEIGHT_DECAY = 0.01
GRAD_CLIP = 0.5
VAL_SPLIT_RATIO = 0.05  # currently unused while the dummy datasets are in place

BASE_MODEL_PATH = Path("models/JiRack_H4_L2_V50257_D768_MSL8192_FF768x4.script.pt")
DATASET_PATH = Path("datasets/dialogues_text_clean.txt")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


def print_model_devices(model):
    """Print and return the set of devices found in the model's state_dict()."""
    sd = model.state_dict()
    devs = set()
    for k, v in sd.items():
        try:
            devs.add(v.device)
        except Exception:
            # Non-tensor entries are counted as CPU.
            devs.add(torch.device("cpu"))
    print("Devices present in model.state_dict():", devs)
    return devs


def safe_load_jit_model(path: Path, map_device: torch.device):
    """
    Load a JIT model with map_location and try to move it to map_device.

    Returns (model, model_device): the model and the device its parameters/buffers live on.
    """
    if not path.exists():
        raise FileNotFoundError(f"JIT model not found: {path}")

    print(f"Loading JIT model from {path} with map_location={map_device} ...")
    model = torch.jit.load(str(path), map_location=str(map_device))
    print("Loaded model. Trying model.to(...) ...")
    try:
        model = model.to(map_device)
        print("model.to(map_device) succeeded.")
    except Exception as e:
        # Some TorchScript modules cannot be moved after loading; fall back to diagnostics.
        print("Warning: model.to(map_device) raised an exception:", e)

    devs = print_model_devices(model)

    if len(devs) == 0:
        model_device = map_device
    elif len(devs) == 1:
        model_device = next(iter(devs))
    else:
        # Mixed devices inside the model: prefer a CUDA device if one is present.
        cuda_devs = [d for d in devs if "cuda" in str(d)]
        model_device = cuda_devs[0] if cuda_devs else next(iter(devs))
        print("Warning: multiple devices found in state_dict(). Selected model_device =", model_device)

    if str(model_device) != str(map_device):
        print(f"Model tensors are on {model_device} but requested map_device is {map_device}.")
        print("Retrying torch.jit.load with map_location=model_device ...")
        try:
            model = torch.jit.load(str(path), map_location=str(model_device))
            try:
                model = model.to(model_device)
            except Exception:
                pass
            devs2 = print_model_devices(model)
            if len(devs2) == 1 and next(iter(devs2)) == model_device:
                print("Successfully reloaded onto the target device.")
        except Exception as e:
            print("Failed to reload the model onto the desired device:", e)

    return model, model_device


def get_logits_from_model(model, inputs):
    """
    Call the model, tolerating the common return-value variants.

    Assumes `inputs` is already on the same device as the model.
    """
    try:
        out = model(inputs)
        # Some models return (logits, ...) tuples or lists; take the first element.
        if isinstance(out, (tuple, list)):
            return out[0]
        return out
    except RuntimeError as e:
        msg = str(e)
        if "Expected all tensors to be on the same device" in msg or "but found at least two devices" in msg:
            print("RuntimeError: likely a cpu/cuda device mismatch inside the model. state_dict() diagnostics:")
            try:
                print_model_devices(model)
            except Exception:
                pass
            raise RuntimeError("Device mismatch while running the JIT model. See printed diagnostics above.") from e
        else:
            raise
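

# val_loader is built in train() but never consumed. A minimal validation pass
# could look like this; the function name and mean-loss reduction are
# illustrative additions, not part of the original script.
@torch.no_grad()
def evaluate(model, loader, model_device):
    """Return the mean cross-entropy loss of `model` over `loader`."""
    criterion = nn.CrossEntropyLoss()
    was_training = model.training
    model.eval()
    total_loss, batches = 0.0, 0
    for xb, yb in loader:
        xb, yb = xb.to(model_device), yb.to(model_device)
        logits = get_logits_from_model(model, xb)
        total_loss += criterion(logits.view(-1, logits.size(-1)), yb.view(-1)).item()
        batches += 1
    if was_training:
        model.train()
    return total_loss / max(batches, 1)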


def train():
    model, model_device = safe_load_jit_model(BASE_MODEL_PATH, device)

    # Placeholder data: random token sequences standing in for the real corpus.
    # See the TextBlockDataset sketch near the top for a real pipeline.
    class DummyDataset(torch.utils.data.Dataset):
        def __init__(self, n=1000, seq_len=TRAIN_SEQ_LEN, vocab_size=50257):
            self.n = n
            self.seq_len = seq_len
            self.vocab_size = vocab_size

        def __len__(self):
            return self.n

        def __getitem__(self, i):
            x = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
            y = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
            return x, y

    train_dataset = DummyDataset(n=2000)
    val_dataset = DummyDataset(n=200)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

    # Optimizer over the ScriptModule's parameters (may be empty for some JIT exports).
    params = list(model.parameters()) if hasattr(model, "parameters") else []
    if len(params) == 0:
        print("Warning: model.parameters() is empty. Make sure the JIT model exposes parameters to optimize.")
    optimizer = optim.AdamW(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY) if params else None
    criterion = nn.CrossEntropyLoss()

    # Gradient scaler for mixed-precision training (disables itself on CPU).
    scaler = GradScaler()

    model.train()

    for epoch in range(1, EPOCHS + 1):
        print(f"Epoch {epoch}/{EPOCHS}")
        epoch_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch} [TRAIN]", leave=False)

        batch_count = 0
        skipped_batches = 0

        for xb, yb in pbar:
            # Guard against corrupt inputs (only meaningful for float tensors).
            if torch.is_floating_point(xb) and (torch.isnan(xb).any() or torch.isinf(xb).any()):
                print(f"\n[E{epoch}] WARNING: NaN or Inf found in input data (xb). Skipping batch.")
                skipped_batches += 1
                continue

            xb = xb.to(model_device)
            yb = yb.to(model_device)

            if optimizer:
                optimizer.zero_grad()

            # Mixed-precision forward pass and loss.
            with autocast():
                logits = get_logits_from_model(model, xb)
                loss = criterion(logits.view(-1, logits.size(-1)), yb.view(-1))

            if torch.isnan(loss) or torch.isinf(loss):
                print(f"\n[E{epoch}] CRITICAL: Loss is NaN or Inf. Skipping backward and update.")
                skipped_batches += 1
                continue

            scaler.scale(loss).backward()

            if optimizer:
                # Unscale before clipping so the threshold applies to the true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(params, GRAD_CLIP)
                scaler.step(optimizer)
                scaler.update()

            loss_val = loss.item()
            epoch_loss += loss_val
            batch_count += 1

            pbar.set_postfix({"loss": f"{loss_val:.4f}", "ppl": f"{math.exp(min(loss_val, 10)):.2f}"})

        avg_loss = epoch_loss / batch_count if batch_count > 0 else float("nan")
        print(f"Average loss for the epoch: {avg_loss:.4f}")

        # val_loader is currently unused; a validation pass could call the
        # evaluate() sketch defined above, e.g.:
        #   val_loss = evaluate(model, val_loader, model_device)

        if skipped_batches > 0:
            print(f"Warning: {skipped_batches} batches were skipped due to NaN/Inf in the data or the loss.")


if __name__ == "__main__":
    train()