# pretrain.py — from woywan/t_n (commit 485c50e)
import os
import time
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tokenizers import Tokenizer
from novel_model import NovelTransformer
# Training configuration
BATCH_SIZE = 4 # batch size; CUDA devices can usually handle larger ones
EPOCHS = 1240
LEARNING_RATE = 5e-5
VOCAB_SIZE = 8000  # vocabulary size used by the model and the loss reshape
D_MODEL = 128  # transformer embedding dimension
NHEAD = 4  # attention heads
NUM_LAYERS = 4
DIM_FEEDFORWARD = 512
DROPOUT = 0.1
MAX_LEN = 4096  # maximum sequence length supported by the model
SAVE_DIR = "./novel_model"
DATA_PATH = "./novel_data/novel_dataset.pt"
TOKENIZER_PATH = "./novel_tokenizer.json"  # NOTE(review): unused in this file
# Ensure the checkpoint/output directory exists
os.makedirs(SAVE_DIR, exist_ok=True)
def load_checkpoint(model, optimizer, checkpoint_path):
    """Restore model/optimizer state from a checkpoint file, if one exists.

    Returns:
        (start_epoch, best_loss): the epoch to resume from and the loss
        stored in the checkpoint, or (1, inf) when no checkpoint is found.
    """
    if not os.path.exists(checkpoint_path):
        print("没有找到检查点,从头开始训练")
        return 1, float('inf')
    print(f"加载检查点: {checkpoint_path}")
    # Map saved tensors onto whatever device is available right now.
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(checkpoint_path, map_location=target_device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_loss = checkpoint['loss']
    print(f"从轮次 {start_epoch} 继续训练,最佳损失: {best_loss:.4f}")
    return start_epoch, best_loss
def load_data():
    """Load the tokenized corpus and build (input, target) pairs.

    The target sequence is the input shifted left by one token, i.e.
    standard next-token-prediction: input [0..n-1], target [1..n].
    """
    print("加载数据...")
    tokens = torch.load(DATA_PATH)
    return TensorDataset(tokens[:, :-1], tokens[:, 1:])
def get_dataloader(dataset):
    """Wrap *dataset* in a shuffling DataLoader with pinned host memory."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True,  # speeds up host-to-GPU transfers
    )
def train(model, dataloader, optimizer, criterion, device, epoch, scheduler=None):
    """Run one training epoch and return the mean loss over all batches.

    Args:
        model: the language model; called as model(src) -> logits.
        dataloader: yields (src, tgt) batch pairs.
        optimizer: optimizer to step per batch.
        criterion: loss over (N, VOCAB_SIZE) logits vs (N,) targets.
        device: device to move batches onto.
        epoch: current epoch number (for logging only).
        scheduler: optional LR scheduler, stepped once per *batch* when
            given — the caller must size its schedule in batches.
    """
    model.train()
    total_loss = 0.0
    start_time = time.time()
    for batch_idx, (src, tgt) in enumerate(dataloader):
        src = src.to(device, non_blocking=True)
        tgt = tgt.to(device, non_blocking=True)
        optimizer.zero_grad()
        output = model(src)
        # reshape (not view): the model's output may be non-contiguous,
        # in which case .view() raises; reshape handles both cases and
        # matches the tgt.reshape(-1) on the same line.
        loss = criterion(output.reshape(-1, VOCAB_SIZE), tgt.reshape(-1))
        loss.backward()
        # Clip gradients to stabilize training.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        total_loss += loss.item()
        if batch_idx % 10 == 0:
            ms_per_batch = (time.time() - start_time) * 1000 / (batch_idx + 1)
            cur_loss = total_loss / (batch_idx + 1)
            print(f'| 轮次 {epoch:3d} | {batch_idx:5d}/{len(dataloader):5d} 批次 | '
                  f'损失 {cur_loss:5.2f} | {ms_per_batch:5.2f} ms/批次')
    return total_loss / len(dataloader)
def main():
    """Entry point: set up device, data, and model, then run the training loop."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    if device.type == 'cuda':
        # Fixed input shapes, so let cuDNN benchmark and cache fast kernels.
        torch.backends.cudnn.benchmark = True
        print(f"CUDA设备: {torch.cuda.get_device_name(0)}")
        print(f"CUDA版本: {torch.version.cuda}")
        print(f"可用GPU数量: {torch.cuda.device_count()}")
    dataset = load_data()
    dataloader = get_dataloader(dataset)
    model = NovelTransformer(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_layers=NUM_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        dropout=DROPOUT,
        max_len=MAX_LEN,
    ).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # token id 0 is padding
    # Cosine annealing over the whole run. T_max is in *epochs*, so the
    # scheduler must be stepped once per epoch. (Previously it was handed
    # to train() and stepped per batch, which decayed the LR to its
    # minimum within roughly the first epoch.)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    checkpoint_path = os.path.join(SAVE_DIR, 'latest_checkpoint.pt')
    start_epoch, best_loss = load_checkpoint(model, optimizer, checkpoint_path)
    # Scheduler state is not checkpointed; fast-forward it on resume so the
    # LR matches the epoch we restart from.
    for _ in range(start_epoch - 1):
        scheduler.step()
    for epoch in range(start_epoch, EPOCHS + 1):
        epoch_start_time = time.time()
        # No scheduler passed in: it is stepped per epoch below, not per batch.
        train_loss = train(model, dataloader, optimizer, criterion, device, epoch)
        scheduler.step()
        print('-' * 89)
        print(f'| 轮次 {epoch:3d} | 时间: {time.time() - epoch_start_time:5.2f}s | '
              f'训练损失 {train_loss:5.2f} | 训练困惑度 {math.exp(train_loss):8.2f}')
        print('-' * 89)
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
        }
        # Latest checkpoint, used for resuming interrupted training.
        torch.save(state, checkpoint_path)
        # Best model so far. NOTE(review): load_checkpoint returns the *last*
        # saved loss, not a running best, so "best" effectively resets on
        # resume — confirm this is intended.
        if train_loss < best_loss:
            best_loss = train_loss
            torch.save(state, os.path.join(SAVE_DIR, 'best_model.pt'))
        # Per-epoch snapshot. NOTE(review): with EPOCHS=1240 this accumulates
        # one full checkpoint per epoch on disk.
        torch.save(state, os.path.join(SAVE_DIR, f'model_epoch_{epoch}.pt'))
if __name__ == "__main__":
    main()