import os
import time
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tokenizers import Tokenizer
from novel_model import NovelTransformer

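# Model and training hyperparameters; paths for data, tokenizer, and checkpoints.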
BATCH_SIZE = 4
EPOCHS = 1240
LEARNING_RATE = 5e-5
VOCAB_SIZE = 8000
D_MODEL = 128
NHEAD = 4
NUM_LAYERS = 4
DIM_FEEDFORWARD = 512
DROPOUT = 0.1
MAX_LEN = 4096
SAVE_DIR = "./novel_model"
DATA_PATH = "./novel_data/novel_dataset.pt"
TOKENIZER_PATH = "./novel_tokenizer.json"

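# Make sure the checkpoint directory exists.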
os.makedirs(SAVE_DIR, exist_ok=True)

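# Restore model/optimizer state from a checkpoint if one exists; returns the
# epoch to resume from and the best loss seen so far.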
def load_checkpoint(model, optimizer, checkpoint_path):
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['loss']
        print(f"Resuming training from epoch {start_epoch}, best loss: {best_loss:.4f}")
        return start_epoch, best_loss
    else:
        print("No checkpoint found, training from scratch")
        return 1, float('inf')

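# Load the pretokenized dataset and build (src, tgt) pairs for next-token
# prediction: tgt is src shifted one position to the right.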
def load_data():
    print("Loading data...")
    data = torch.load(DATA_PATH)

    src = data[:, :-1]
    tgt = data[:, 1:]

    return TensorDataset(src, tgt)

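# pin_memory lets the .to(device, non_blocking=True) copies in the training
# loop overlap with computation on CUDA.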
def get_dataloader(dataset):
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

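# One training epoch: next-token cross-entropy with gradient clipping and an
# optional per-batch LR scheduler step. Returns the mean loss over the epoch.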
def train(model, dataloader, optimizer, criterion, device, epoch, scheduler=None):
    model.train()
    total_loss = 0
    start_time = time.time()

    for batch_idx, (src, tgt) in enumerate(dataloader):
        src, tgt = src.to(device, non_blocking=True), tgt.to(device, non_blocking=True)

        optimizer.zero_grad()
        output = model(src)

        # Flatten (batch, seq, vocab) logits and (batch, seq) targets for cross-entropy.
        loss = criterion(output.reshape(-1, VOCAB_SIZE), tgt.reshape(-1))
        loss.backward()

        # Clip gradients to keep updates stable.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()

        if batch_idx % 10 == 0:
            ms_per_batch = (time.time() - start_time) * 1000 / (batch_idx + 1)
            cur_loss = total_loss / (batch_idx + 1)
            print(f'| epoch {epoch:3d} | {batch_idx:5d}/{len(dataloader):5d} batches | '
                  f'loss {cur_loss:5.2f} | {ms_per_batch:5.2f} ms/batch')

    return total_loss / len(dataloader)

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if device.type == 'cuda':
        # Let cuDNN autotune kernels; helps when input shapes are fixed.
        torch.backends.cudnn.benchmark = True
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"CUDA version: {torch.version.cuda}")
        print(f"Available GPUs: {torch.cuda.device_count()}")

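    # Data pipeline.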
    dataset = load_data()
    dataloader = get_dataloader(dataset)

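    # Instantiate the model and move it to the target device.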
    model = NovelTransformer(
        vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        nhead=NHEAD,
        num_layers=NUM_LAYERS,
        dim_feedforward=DIM_FEEDFORWARD,
        dropout=DROPOUT,
        max_len=MAX_LEN
    )
    model = model.to(device)

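    # AdamW optimizer; cross-entropy skips padding positions (pad id assumed to be 0).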
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

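    # Cosine LR decay. train() calls scheduler.step() once per batch, so T_max
    # must be the total number of optimizer steps, not the number of epochs.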
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS * len(dataloader))

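    # Resume from the latest checkpoint if present. Note that the scheduler
    # state is not checkpointed, so the LR schedule restarts on resume.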
    checkpoint_path = os.path.join(SAVE_DIR, 'latest_checkpoint.pt')
    start_epoch, best_loss = load_checkpoint(model, optimizer, checkpoint_path)

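    # Main training loop.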
    for epoch in range(start_epoch, EPOCHS + 1):
        epoch_start_time = time.time()
        train_loss = train(model, dataloader, optimizer, criterion, device, epoch, scheduler)

        print('-' * 89)
        print(f'| epoch {epoch:3d} | time: {time.time() - epoch_start_time:5.2f}s | '
              f'train loss {train_loss:5.2f} | train perplexity {math.exp(train_loss):8.2f}')
        print('-' * 89)

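        # Always overwrite the latest checkpoint so training can resume.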
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
        }, checkpoint_path)

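        # Keep a separate copy of the best model by training loss.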
        if train_loss < best_loss:
            best_loss = train_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': train_loss,
            }, os.path.join(SAVE_DIR, 'best_model.pt'))

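        # Also keep a per-epoch snapshot.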
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
        }, os.path.join(SAVE_DIR, f'model_epoch_{epoch}.pt'))

if __name__ == "__main__":
    main()