# vAIbe_diffutslator / train.py
# forthezero's picture
# Upload 28 files
# 2651102 verified
"""
训练脚本
支持快速验证和完整训练,可暂停和恢复
"""
import os
import sys
import signal
import argparse
import time
from typing import Optional
from datetime import datetime
# Use every available CPU core. os.cpu_count() may return None when the count
# cannot be determined, so fall back to a single thread in that case.
_CPU_THREADS = os.cpu_count() or 1
# Export the OpenMP / MKL thread counts BEFORE importing torch: these runtimes
# typically read the variables when they are initialised, so setting them
# after the import may have no effect.
os.environ['OMP_NUM_THREADS'] = str(_CPU_THREADS)
os.environ['MKL_NUM_THREADS'] = str(_CPU_THREADS)
import torch
# Let PyTorch use the same number of intra-op threads.
torch.set_num_threads(_CPU_THREADS)
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from config import Config
from tokenizer import Tokenizer, train_tokenizers
from dataset import load_all_data, create_dataloaders
from embedding import DualLanguageEmbedding, DualOutputProjection
from model import create_model
from diffusion import get_diffusion, NoiseScheduler
from switcher import create_switcher
from utils import ProgressTracker, count_parameters, format_number, save_checkpoint, load_checkpoint
class Trainer:
    """Trainer for the dual-language (zh/en) diffusion model, running on CPU.

    Builds tokenizers, data loaders, the embedding / output-projection layers,
    the noise-prediction model, the language switcher, the optimizer and the
    learning-rate schedule from a ``Config``; then runs an epoch loop with
    validation and checkpointing.

    SIGINT / SIGTERM request a graceful stop: the handler only raises a flag,
    and the training loop writes an ``interrupted`` checkpoint at the next
    batch boundary.  Resuming from a checkpoint continues with the epoch
    after the stored one; the remainder of an interrupted epoch is not
    replayed.
    """

    def __init__(self, config: "Config"):
        """Build all components and install the SIGINT / SIGTERM handlers.

        Args:
            config: Project configuration object (paths, data, model,
                diffusion and training settings).
        """
        self.config = config
        self.device = torch.device("cpu")  # CPU-only training

        self._init_components()

        # Training state (restored by ``_load_checkpoint`` when resuming).
        self.current_epoch = 0
        self.global_step = 0  # counts optimizer updates, not batches
        self.best_loss = float('inf')
        self.should_stop = False

        # Turn Ctrl-C / termination into a graceful stop request.
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

    @staticmethod
    def _optimizer_steps(num_batches: int, accumulation: int) -> int:
        """Return how many optimizer updates ``num_batches`` batches produce.

        One update is made per ``accumulation`` batches, plus one more for a
        trailing partial group (ceiling division).
        """
        return (num_batches + accumulation - 1) // accumulation

    def _trainable_parameters(self) -> list:
        """Return the parameters of every optimized module as one flat list."""
        return [
            *self.embedding.parameters(),
            *self.output_proj.parameters(),
            *self.model.parameters(),
            *self.switcher.parameters(),
        ]

    def _init_components(self):
        """Create tokenizers, data loaders, modules, optimizer and schedule."""
        print("初始化训练组件...")

        # Load cached tokenizers, or train and cache them.
        tokenizer_path = os.path.join(self.config.project_dir, self.config.data.cache_dir)
        zh_tokenizer_path = os.path.join(tokenizer_path, "tokenizer_zh.json")
        en_tokenizer_path = os.path.join(tokenizer_path, "tokenizer_en.json")

        splits = None  # (train, val, test) pairs, loaded at most once
        if os.path.exists(zh_tokenizer_path) and os.path.exists(en_tokenizer_path):
            print(" 加载已有分词器...")
            self.zh_tokenizer = Tokenizer.load(zh_tokenizer_path)
            self.en_tokenizer = Tokenizer.load(en_tokenizer_path)
        else:
            print(" 训练分词器...")
            # The training split doubles as the tokenizer corpus; keep the
            # loaded splits so the dataset is not read a second time below.
            splits = load_all_data(self.config)
            tokenizer_corpus = splits[0]
            zh_texts = [pair.zh for pair in tokenizer_corpus]
            en_texts = [pair.en for pair in tokenizer_corpus]
            self.zh_tokenizer, self.en_tokenizer = train_tokenizers(
                self.config, zh_texts, en_texts
            )
            # Make sure the cache directory exists before writing into it.
            os.makedirs(tokenizer_path, exist_ok=True)
            self.zh_tokenizer.save(zh_tokenizer_path)
            self.en_tokenizer.save(en_tokenizer_path)

        # Datasets / loaders (the test split is not used during training).
        print(" 加载数据集...")
        if splits is None:
            splits = load_all_data(self.config)
        train_pairs, val_pairs, _ = splits
        self.train_loader, self.val_loader = create_dataloaders(
            train_pairs, val_pairs,
            self.zh_tokenizer, self.en_tokenizer,
            self.config
        )

        # Embedding layer shared by both languages.
        print(" 初始化嵌入层...")
        self.embedding = DualLanguageEmbedding(
            vocab_size_zh=self.zh_tokenizer.vocab_size_actual,
            vocab_size_en=self.en_tokenizer.vocab_size_actual,
            d_model=self.config.model.d_model,
            max_len=self.config.model.max_len,
            dropout=self.config.model.dropout,
        )

        # Output projection.
        # NOTE(review): within this class output_proj is optimized and
        # checkpointed but never enters a loss term - confirm this is intended.
        self.output_proj = DualOutputProjection(
            d_model=self.config.model.d_model,
            vocab_size_zh=self.zh_tokenizer.vocab_size_actual,
            vocab_size_en=self.en_tokenizer.vocab_size_actual,
        )

        # Noise-prediction model.
        print(" 初始化模型...")
        self.model = create_model(self.config)

        # Language switcher (classifies a noisy embedding as zh or en).
        self.switcher = create_switcher(self.config)

        # Diffusion process and its noise scheduler.
        self.diffusion, self.ddim_sampler = get_diffusion(self.config)
        self.scheduler = self.diffusion.scheduler.to(self.device)

        # Optimizer over every module's parameters.
        self.optimizer = optim.AdamW(
            self._trainable_parameters(),
            lr=self.config.training.learning_rate,
            weight_decay=self.config.training.weight_decay,
        )

        # LR schedule.  The scheduler is stepped once per optimizer update
        # (i.e. once per accumulation group), so total_steps must count
        # updates rather than batches or the cycle would never complete.
        steps_per_epoch = self._optimizer_steps(
            len(self.train_loader),
            self.config.training.gradient_accumulation,
        )
        total_steps = steps_per_epoch * self.config.training.epochs
        self.lr_scheduler = OneCycleLR(
            self.optimizer,
            max_lr=self.config.training.learning_rate,
            total_steps=total_steps,
            pct_start=0.1,
            anneal_strategy='cos',
        )

        # Loss functions.
        self.mse_loss = nn.MSELoss()
        self.ce_loss = nn.CrossEntropyLoss()

        # Report the model size.
        total_params = sum(
            count_parameters(module)
            for module in (self.embedding, self.output_proj, self.model, self.switcher)
        )
        print(f" 总参数量: {format_number(total_params)}")

    def _signal_handler(self, signum, frame):
        """Request a graceful stop on SIGINT / SIGTERM.

        Only a flag is set here.  Writing the checkpoint inside the handler
        could capture the state between ``optimizer.step()`` and
        ``lr_scheduler.step()``; ``train`` saves the ``interrupted``
        checkpoint at the next batch boundary instead.
        """
        print("\n\n收到中断信号,保存检查点...")
        self.should_stop = True

    def _save_checkpoint(self, name: str):
        """Write the training state to ``<checkpoint_dir>/<name>.pt``.

        The file is written to a temporary path first and then moved into
        place, so an interrupted write cannot leave a truncated checkpoint
        under the final name.

        Args:
            name: Checkpoint file name without the ``.pt`` extension.
        """
        checkpoint_dir = os.path.join(self.config.project_dir, self.config.training.checkpoint_dir)
        os.makedirs(checkpoint_dir, exist_ok=True)
        path = os.path.join(checkpoint_dir, f"{name}.pt")
        state = {
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'best_loss': self.best_loss,
            'embedding': self.embedding.state_dict(),
            'output_proj': self.output_proj.state_dict(),
            'model': self.model.state_dict(),
            'switcher': self.switcher.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
            'config': self.config,
        }
        tmp_path = path + ".tmp"
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
        print(f" 检查点已保存: {path}")

    def _load_checkpoint(self, path: str):
        """Restore training state, module weights, optimizer and schedule.

        Args:
            path: Path of a checkpoint written by ``_save_checkpoint``.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects (the
        # checkpoint embeds the Config instance).  Only load checkpoints from
        # a trusted source.
        state = torch.load(path, map_location=self.device, weights_only=False)
        self.current_epoch = state['epoch']
        self.global_step = state['global_step']
        self.best_loss = state['best_loss']
        self.embedding.load_state_dict(state['embedding'])
        self.output_proj.load_state_dict(state['output_proj'])
        self.model.load_state_dict(state['model'])
        self.switcher.load_state_dict(state['switcher'])
        self.optimizer.load_state_dict(state['optimizer'])
        self.lr_scheduler.load_state_dict(state['lr_scheduler'])
        print(f" 从检查点恢复: epoch={self.current_epoch}, step={self.global_step}")

    def train_step(self, batch: dict) -> dict:
        """Run forward and backward for one batch (no optimizer update).

        Args:
            batch: Mapping with the keys ``zh_ids``, ``en_ids``, ``zh_lens``
                and ``en_lens``, each holding a tensor.

        Returns:
            Dict of Python floats: ``loss`` (total, before the
            gradient-accumulation scaling), ``loss_noise_zh``,
            ``loss_noise_en`` and ``loss_switch``.
        """
        zh_ids = batch['zh_ids'].to(self.device)
        en_ids = batch['en_ids'].to(self.device)
        zh_lens = batch['zh_lens'].to(self.device)
        en_lens = batch['en_lens'].to(self.device)
        batch_size = zh_ids.size(0)

        # Embed both sides.
        zh_emb = self.embedding(zh_ids, 'zh', zh_lens)
        en_emb = self.embedding(en_ids, 'en', en_lens)

        # Independent random timesteps per language.
        timesteps = self.config.diffusion.timesteps
        t_zh = torch.randint(0, timesteps, (batch_size,), device=self.device)
        t_en = torch.randint(0, timesteps, (batch_size,), device=self.device)

        # Forward diffusion: noisy embeddings plus the noise that was added.
        zh_noisy, zh_noise = self.diffusion.q_sample(zh_emb, t_zh)
        en_noisy, en_noise = self.diffusion.q_sample(en_emb, t_en)

        # Noise-prediction loss.
        zh_noise_pred = self.model(zh_noisy, t_zh, lang='zh')
        en_noise_pred = self.model(en_noisy, t_en, lang='en')
        loss_noise_zh = self.mse_loss(zh_noise_pred, zh_noise)
        loss_noise_en = self.mse_loss(en_noise_pred, en_noise)

        # Language-switcher loss; labels: 0 = Chinese, 1 = English.
        zh_labels = torch.zeros(batch_size, dtype=torch.long, device=self.device)
        en_labels = torch.ones(batch_size, dtype=torch.long, device=self.device)
        zh_switch_logits = self.switcher(zh_noisy)
        en_switch_logits = self.switcher(en_noisy)
        loss_switch = (
            self.ce_loss(zh_switch_logits, zh_labels) +
            self.ce_loss(en_switch_logits, en_labels)
        ) / 2

        total = loss_noise_zh + loss_noise_en + 0.1 * loss_switch

        # Scale so gradients summed over an accumulation group average out.
        (total / self.config.training.gradient_accumulation).backward()

        return {
            'loss': total.item(),
            'loss_noise_zh': loss_noise_zh.item(),
            'loss_noise_en': loss_noise_en.item(),
            'loss_switch': loss_switch.item(),
        }

    def train_epoch(self, epoch: int) -> float:
        """Train for one epoch and return the mean per-batch loss.

        An optimizer update is made every ``gradient_accumulation`` batches
        and also after the final batch, so a trailing partial group is applied
        instead of leaking into the next epoch.  (That partial group is still
        scaled by the full accumulation factor.)

        Args:
            epoch: 1-based epoch number, used for the progress display.

        Returns:
            Mean loss over the batches actually processed, or NaN when no
            batch was processed (e.g. a stop was requested immediately).
        """
        self.model.train()
        self.embedding.train()
        self.output_proj.train()
        self.switcher.train()

        accumulation = self.config.training.gradient_accumulation
        batch_size = self.config.training.batch_size
        num_batches = len(self.train_loader)
        params = self._trainable_parameters()
        tracker = ProgressTracker(
            total_steps=num_batches,
            desc=f"Epoch {epoch}/{self.config.training.epochs}"
        )

        total_loss = 0.0
        processed = 0
        for batch_idx, batch in enumerate(self.train_loader):
            if self.should_stop:
                break

            metrics = self.train_step(batch)
            total_loss += metrics['loss']
            processed += 1

            group_complete = (batch_idx + 1) % accumulation == 0
            last_batch = batch_idx + 1 == num_batches
            if group_complete or last_batch:
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()
                self.global_step += 1

            # Live progress line, rewritten in place after every batch.
            tracker.update(batch_idx + 1, metrics['loss'])
            elapsed = tracker.elapsed
            samples_speed = tracker.count * batch_size / elapsed if elapsed > 0 else 0
            progress_str = tracker.format_progress(metrics['loss'])
            progress_str = progress_str.replace("it/s", "samples/s")
            print(f"\r{progress_str} ({samples_speed:.0f} samples/s)", end="", flush=True)

        print()  # finish the progress line
        return total_loss / processed if processed else float("nan")

    @torch.no_grad()
    def validate(self, max_batches: Optional[int] = 50) -> float:
        """Compute the mean noise-prediction loss on the validation loader.

        Args:
            max_batches: Upper bound on the number of batches evaluated;
                ``None`` evaluates the whole loader.

        Returns:
            Mean of (zh MSE + en MSE) over the evaluated batches, or NaN when
            no batch was evaluated.
        """
        self.model.eval()
        self.embedding.eval()
        self.output_proj.eval()
        self.switcher.eval()

        total_loss = 0.0
        evaluated = 0
        for batch_idx, batch in enumerate(self.val_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break

            zh_ids = batch['zh_ids'].to(self.device)
            en_ids = batch['en_ids'].to(self.device)
            zh_lens = batch['zh_lens'].to(self.device)
            en_lens = batch['en_lens'].to(self.device)
            batch_size = zh_ids.size(0)

            zh_emb = self.embedding(zh_ids, 'zh', zh_lens)
            en_emb = self.embedding(en_ids, 'en', en_lens)

            # One shared random timestep vector for both languages.
            t = torch.randint(0, self.config.diffusion.timesteps, (batch_size,), device=self.device)

            zh_noisy, zh_noise = self.diffusion.q_sample(zh_emb, t)
            en_noisy, en_noise = self.diffusion.q_sample(en_emb, t)

            zh_noise_pred = self.model(zh_noisy, t, lang='zh')
            en_noise_pred = self.model(en_noisy, t, lang='en')

            loss = self.mse_loss(zh_noise_pred, zh_noise) + self.mse_loss(en_noise_pred, en_noise)
            total_loss += loss.item()
            evaluated += 1

        return total_loss / evaluated if evaluated else float("nan")

    def train(self):
        """Run the epoch loop with validation and checkpointing.

        Starts after ``current_epoch`` (non-zero when resumed).  When a stop
        was requested, validation of the unfinished epoch is skipped and an
        ``interrupted`` checkpoint is written before returning.
        """
        print("\n" + "=" * 60)
        print("开始训练")
        print("=" * 60)
        start_time = time.time()

        for epoch in range(self.current_epoch + 1, self.config.training.epochs + 1):
            if self.should_stop:
                break
            self.current_epoch = epoch

            train_loss = self.train_epoch(epoch)
            if self.should_stop:
                # The epoch was cut short: do not validate or promote it.
                break

            val_loss = self.validate()

            print(f"\nEpoch {epoch} 完成:")
            print(f" 训练损失: {train_loss:.4f}")
            print(f" 验证损失: {val_loss:.4f}")

            # Periodic checkpoint.
            if epoch % self.config.training.save_every == 0:
                self._save_checkpoint(f"epoch_{epoch}")

            # Best-so-far checkpoint (a NaN loss never compares as better).
            if val_loss < self.best_loss:
                self.best_loss = val_loss
                self._save_checkpoint("best")
                print(" 新的最佳模型!")

        if self.should_stop:
            # Saved here, at a batch boundary, rather than in the handler.
            self._save_checkpoint("interrupted")

        elapsed = time.time() - start_time
        print("\n" + "=" * 60)
        if self.should_stop:
            print(f"训练中断! 总用时: {elapsed/60:.1f} 分钟")
        else:
            print(f"训练完成! 总用时: {elapsed/60:.1f} 分钟")
        print(f"最佳验证损失: {self.best_loss:.4f}")
        print("=" * 60)
def main():
    """Command-line entry point.

    Parses the arguments, builds a quick or default ``Config``, applies the
    optional overrides, prints a configuration summary, optionally restores a
    checkpoint and then starts training.
    """
    cli = argparse.ArgumentParser(description="Diffutslator 训练脚本")

    # Mode switches.
    cli.add_argument("--quick", action="store_true", help="快速验证模式")
    cli.add_argument("--full", action="store_true", help="完整训练模式")

    # Optional overrides; all of them default to None.
    for flag, value_type, help_text in (
        ("--samples", int, "使用的数据量"),
        ("--epochs", int, "训练轮数"),
        ("--batch-size", int, "批量大小"),
        ("--resume", str, "恢复训练的检查点路径"),
    ):
        cli.add_argument(flag, type=value_type, default=None, help=help_text)

    opts = cli.parse_args()

    # Anything other than --quick runs with the default (full) configuration.
    if opts.quick:
        config = Config.quick()
        print("模式: 快速验证")
    else:
        config = Config()
        print("模式: 完整训练")

    # Apply only the overrides that carry a truthy value.
    if opts.samples:
        config.data.max_samples = opts.samples
    for attr, value in (
        ("epochs", opts.epochs),
        ("batch_size", opts.batch_size),
        ("resume", opts.resume),
    ):
        if value:
            setattr(config.training, attr, value)

    # Configuration summary.
    training = config.training
    effective_batch = training.batch_size * training.gradient_accumulation
    print("\n配置:")
    print(f" 数据量: {config.data.max_samples or '全部'}")
    print(f" 批量大小: {training.batch_size}")
    print(f" 梯度累积: {training.gradient_accumulation}")
    print(f" 有效批量: {effective_batch}")
    print(f" 训练轮数: {training.epochs}")
    print(f" 学习率: {training.learning_rate}")

    trainer = Trainer(config)

    # Restore a previous run when a checkpoint path is configured.
    checkpoint_path = config.training.resume
    if checkpoint_path:
        trainer._load_checkpoint(checkpoint_path)

    trainer.train()


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()