import os
import time
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tokenizers import Tokenizer

from novel_model import NovelTransformer

# Configuration
BATCH_SIZE = 4          # a CUDA device can usually handle a larger batch size
EPOCHS = 1240
LEARNING_RATE = 5e-5
VOCAB_SIZE = 8000
D_MODEL = 128
NHEAD = 4
NUM_LAYERS = 4
DIM_FEEDFORWARD = 512
DROPOUT = 0.1
MAX_LEN = 4096
SAVE_DIR = "./novel_model"
DATA_PATH = "./novel_data/novel_dataset.pt"
TOKENIZER_PATH = "./novel_tokenizer.json"

# Create the save directory
os.makedirs(SAVE_DIR, exist_ok=True)


# Resume from a checkpoint if one exists
def load_checkpoint(model, optimizer, checkpoint_path):
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint: {checkpoint_path}")
        # map_location moves the checkpoint tensors onto the current device
        checkpoint = torch.load(
            checkpoint_path,
            map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['loss']
        print(f"Resuming training from epoch {start_epoch}, best loss: {best_loss:.4f}")
        return start_epoch, best_loss
    else:
        print("No checkpoint found, training from scratch")
        return 1, float('inf')


# Load the pre-tokenized dataset
def load_data():
    print("Loading data...")
    data = torch.load(DATA_PATH)
    # Build inputs and targets for next-token prediction:
    # input:  [0, 1, 2, ..., n-1]
    # target: [1, 2, 3, ..., n]
    src = data[:, :-1]
    tgt = data[:, 1:]
    return TensorDataset(src, tgt)


# Build the data loader
def get_dataloader(dataset):
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)


# Train for one epoch
def train(model, dataloader, optimizer, criterion, device, epoch, scheduler=None):
    model.train()
    total_loss = 0
    start_time = time.time()

    for batch_idx, (src, tgt) in enumerate(dataloader):
        src, tgt = src.to(device, non_blocking=True), tgt.to(device, non_blocking=True)

        optimizer.zero_grad()
        output = model(src)

        # Compute the loss (padding tokens are ignored via ignore_index)
        loss = criterion(output.reshape(-1, VOCAB_SIZE), tgt.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()

        if batch_idx % 10 == 0:
            ms_per_batch = (time.time() - start_time) * 1000 / (batch_idx + 1)
            cur_loss = total_loss / (batch_idx + 1)
            print(f'| epoch {epoch:3d} | {batch_idx:5d}/{len(dataloader):5d} batches | '
                  f'loss {cur_loss:5.2f} | {ms_per_batch:5.2f} ms/batch')

    return total_loss / len(dataloader)


def main():
    # Select the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # CUDA-specific optimizations
    if device.type == 'cuda':
        torch.backends.cudnn.benchmark = True
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA version: {torch.version.cuda}")
        print(f"Available GPUs: {torch.cuda.device_count()}")

    # Load data
    dataset = load_data()
    dataloader = get_dataloader(dataset)

    # Build the model
    model = NovelTransformer(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_layers=NUM_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        dropout=DROPOUT,
        max_len=MAX_LEN
    )

    # Move the model to the device
    model = model.to(device)

    # Optimizer and loss function
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # token id 0 is assumed to be padding

    # Learning-rate scheduler: cosine annealing over the full run
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # Try to resume from the latest checkpoint
    checkpoint_path = os.path.join(SAVE_DIR, 'latest_checkpoint.pt')
    start_epoch, best_loss = load_checkpoint(model, optimizer, checkpoint_path)

    # Training loop
    for epoch in range(start_epoch, EPOCHS + 1):
        epoch_start_time = time.time()
        train_loss = train(model, dataloader, optimizer, criterion, device, epoch, scheduler)

        print('-' * 89)
        print(f'| epoch {epoch:3d} | time: {time.time() - epoch_start_time:5.2f}s | '
              f'train loss {train_loss:5.2f} | train perplexity {math.exp(train_loss):8.2f}')
        print('-' * 89)

        # Save the latest checkpoint (used to resume training)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
        }, checkpoint_path)

        # Save the best model so far
        if train_loss < best_loss:
            best_loss = train_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss,
            }, os.path.join(SAVE_DIR, 'best_model.pt'))

        # Also save a snapshot every epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
        }, os.path.join(SAVE_DIR, f'model_epoch_{epoch}.pt'))


if __name__ == "__main__":
    main()