import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import math
from pathlib import Path

from your_model_file import JiRack_H4_L2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

model = JiRack_H4_L2().to(device)

state_dict = torch.load("models/JiRack_H4_L2_V50257_D768_MSL8192_FF3072.pt", map_location=device)
model.load_state_dict(state_dict)
print("Weights loaded from .pt file")

BATCH_SIZE = 12
SEQ_LEN = 256
EPOCHS = 10
LR = 5e-5
WARMUP_STEPS = 100


class DummyDataset(torch.utils.data.Dataset):
    """Random-token stand-in for a real tokenized corpus (smoke-test only)."""

    def __init__(self, n=10000):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        x = torch.randint(0, 50257, (SEQ_LEN,))
        # Targets are the inputs shifted left by one; roll(-1) wraps the final
        # target around to the first token, which is fine for dummy data.
        return x, x.roll(-1)
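

# DummyDataset is only a smoke test. As a minimal sketch of a real next-token
# dataset (hypothetical: the script does not show a corpus), slicing a
# pre-tokenized 1-D LongTensor of token ids into shifted (input, target)
# windows avoids the wrap-around targets above:
class CorpusDataset(torch.utils.data.Dataset):
    def __init__(self, token_ids):
        self.token_ids = token_ids  # 1-D LongTensor of ids for the whole corpus

    def __len__(self):
        return (len(self.token_ids) - 1) // SEQ_LEN

    def __getitem__(self, i):
        start = i * SEQ_LEN
        chunk = self.token_ids[start : start + SEQ_LEN + 1]
        return chunk[:-1], chunk[1:]  # targets are inputs shifted by one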


train_loader = DataLoader(DummyDataset(), batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
# Loss scaling is only useful on CUDA; disabling it elsewhere keeps the
# script runnable on the CPU fallback device.
scaler = GradScaler(device.type, enabled=device.type == "cuda")
criterion = nn.CrossEntropyLoss()

global_step = 0
model.train()

for epoch in range(1, EPOCHS + 1):
    total_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")

    for xb, yb in pbar:
        global_step += 1
        xb, yb = xb.to(device), yb.to(device)

        # Linear LR warmup, applied before the optimizer step so the scaled
        # rate is the one actually used; min(1.0, ...) holds the LR at its
        # full value once warmup is over.
        lr_scale = min(1.0, global_step / WARMUP_STEPS)
        for pg in optimizer.param_groups:
            pg["lr"] = LR * lr_scale

        optimizer.zero_grad()

        with autocast(device.type):
            logits = model(xb)
            loss = criterion(logits.view(-1, logits.size(-1)), yb.view(-1))

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale first so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}", "ppl": f"{math.exp(loss.item()):.1f}"})

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch} finished | Average loss: {avg_loss:.4f} | Perplexity: {math.exp(avg_loss):.2f}\n")

Path("models").mkdir(exist_ok=True)  # ensure the checkpoint directory exists
torch.save(model.state_dict(), "models/JiRack_H4_L2_finetuned.pt")


class JITWrapper(nn.Module):
    """Thin wrapper giving torch.jit.trace a plain (ids) -> logits signature."""

    def __init__(self, m):
        super().__init__()
        self.m = m

    def forward(self, x):
        return self.m(x)


# Trace on CPU: the model is moved to CPU here, so the example input must be
# a CPU tensor as well.
dummy = torch.randint(0, 50257, (1, 256))
traced = torch.jit.trace(JITWrapper(model.cpu().eval()), dummy)
traced.save("models/JiRack_H4_L2_finetuned.script.pt")
print("Trained model saved + exported to JIT for inference")