# MultiModal / posttrain.py — post-training (SFT + optional RLHF) script.
import os
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from pathlib import Path
import logging
from tqdm import tqdm
import json
from datetime import datetime
import copy
from model import MultiModalDenseTransformer
from data_loader import (
create_posttrain_dataloader,
create_preference_dataloader
)
from data_config import POSTTRAIN_MIX
from reward_model import RewardModel, RewardModelTrainer
from grpo import GRPOTrainer
from typing import Optional
# Root logging configuration for the whole script.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# NOTE(review): HF_ENDPOINT is set AFTER `from transformers import ...` above;
# huggingface_hub may read this variable at import time, so the mirror setting
# might not take effect here — confirm, and consider setting it before imports.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
class PostTrainer:
    """Supervised fine-tuning (SFT) trainer for MultiModalDenseTransformer.

    Handles mixed-precision training (CUDA AMP), gradient accumulation,
    gradient clipping, periodic evaluation, and checkpointing.
    """

    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 1e-5,
        weight_decay: float = 0.01,
        num_epochs: int = 3,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        eval_interval: int = 500,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/posttrain"
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )
        # Mixed precision is only enabled when CUDA is available; on CPU the
        # GradScaler/autocast contexts below become no-ops.
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
        # Training hyperparameters
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval
        # Checkpoint management
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        # Training state
        self.global_step = 0
        self.best_eval_loss = float('inf')
        logger.info("PostTrainer initialized:")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Num Epochs: {num_epochs}")
        logger.info(f"  Gradient Accumulation: {gradient_accumulation_steps}")

    def _build_position_ids(self, attention_mask: torch.Tensor) -> torch.Tensor:
        """Consecutive position ids over non-pad tokens, zeros at padding.

        Vectorized replacement for the original per-row Python loop:
        cumsum over the mask gives 1..k on real tokens; subtracting 1 and
        multiplying by the mask yields 0..k-1 on real tokens and 0 at pads.
        """
        mask = attention_mask.long()
        return (torch.cumsum(mask, dim=1) - 1) * mask

    def train_step(self, batch: dict) -> dict:
        """Run one forward/backward pass on a single SFT batch.

        The batch must contain 'instruction', 'response', 'instruction_mask'
        and 'response_mask' tensors. Cross-entropy loss is computed only over
        response tokens: instruction and padding positions are set to -100.

        Returns a dict with the raw (unscaled by accumulation) scalar 'loss'.
        """
        instruction_ids = batch['instruction'].to(self.device)
        response_ids = batch['response'].to(self.device)
        instruction_mask = batch['instruction_mask'].to(self.device)
        response_mask = batch['response_mask'].to(self.device)
        input_ids = torch.cat([instruction_ids, response_ids], dim=1)
        attention_mask = torch.cat([instruction_mask, response_mask], dim=1)
        position_ids = self._build_position_ids(attention_mask)
        labels = input_ids.clone()
        # Mask out the instruction prefix and padding so the loss only
        # covers response tokens.
        instr_len = instruction_ids.shape[1]
        labels[:, :instr_len] = -100
        labels[attention_mask == 0] = -100
        # Wrap the token ids in the model's segment-based input format.
        input_data = {
            'segments': [{
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            }]
        }
        # Forward pass
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            outputs = self.model(input_data, attention_mask=attention_mask,
                                 position_ids=position_ids)
            logits = outputs['logits']
            # Shift for next-token prediction.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
        raw_loss = loss.item()
        # Scale down so gradients average correctly across accumulation steps.
        loss = loss / self.gradient_accumulation_steps
        # Backward pass (scaled for AMP)
        self.scaler.scale(loss).backward()
        return {
            'loss': raw_loss
        }

    def optimizer_step(self):
        """Clip gradients, apply the optimizer update, and reset gradients.

        Returns the pre-clipping gradient norm as a float.
        """
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.global_step += 1
        return grad_norm.item()

    @torch.no_grad()
    def evaluate(self, dataloader, max_batches: int = 50) -> float:
        """Return the average response-token loss over up to max_batches batches."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        for i, batch in enumerate(dataloader):
            if i >= max_batches:
                break
            if batch is None:
                continue
            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            instruction_mask = batch['instruction_mask'].to(self.device)
            response_mask = batch['response_mask'].to(self.device)
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)
            attention_mask = torch.cat([instruction_mask, response_mask], dim=1)
            position_ids = self._build_position_ids(attention_mask)
            labels = input_ids.clone()
            labels[:, :instruction_ids.shape[1]] = -100
            # Mask padding via the attention mask instead of comparing token
            # ids against pad_token_id: main() sets pad_token = eos_token, so
            # an id comparison would also mask genuine EOS response tokens.
            labels[attention_mask == 0] = -100
            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': input_ids,
                    'modality_id': 0
                }]
            }
            # Forward with the same mask/positions as training for a
            # consistent eval loss.
            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.model(input_data, attention_mask=attention_mask,
                                     position_ids=position_ids)
                logits = outputs['logits']
                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = labels[:, 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100
                )
            total_loss += loss.item()
            num_batches += 1
        self.model.train()
        return total_loss / max(num_batches, 1)

    def train(
        self,
        train_dataloader,
        eval_dataloader=None,
        resume_from: Optional[str] = None
    ):
        """Main SFT training loop with accumulation, eval, and checkpointing."""
        logger.info("\n" + "="*80)
        logger.info("Starting Post-Training (SFT)")
        logger.info("="*80 + "\n")
        if resume_from:
            self.load_checkpoint(resume_from)
        self.model.train()
        for epoch in range(self.num_epochs):
            logger.info(f"\nEpoch {epoch+1}/{self.num_epochs}")
            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
            running_loss = 0.0
            step_in_accumulation = 0
            for batch_idx, batch in enumerate(progress_bar):
                if batch is None:
                    continue
                # Training step (forward + backward)
                stats = self.train_step(batch)
                running_loss += stats['loss']
                step_in_accumulation += 1
                # Optimizer update after a full accumulation window
                if step_in_accumulation == self.gradient_accumulation_steps:
                    grad_norm = self.optimizer_step()
                    step_in_accumulation = 0
                # Progress bar
                progress_bar.set_postfix({'loss': f"{stats['loss']:.4f}"})
                # Logging
                if self.global_step % self.log_interval == 0:
                    avg_loss = running_loss / self.log_interval
                    logger.info(
                        f"Step {self.global_step} | "
                        f"Epoch {epoch+1} | "
                        f"Loss: {avg_loss:.4f}"
                    )
                    running_loss = 0.0
                # Evaluation
                if eval_dataloader and self.global_step % self.eval_interval == 0:
                    eval_loss = self.evaluate(eval_dataloader)
                    logger.info(f"Eval Loss: {eval_loss:.4f}")
                    if eval_loss < self.best_eval_loss:
                        self.best_eval_loss = eval_loss
                        self.save_checkpoint(
                            self.checkpoint_dir / "best_model.pt",
                            is_best=True
                        )
                # Periodic checkpoint
                if self.global_step % self.save_interval == 0:
                    self.save_checkpoint(
                        self.checkpoint_dir / f"step_{self.global_step}.pt"
                    )
            # Flush any leftover accumulated gradients (when the number of
            # batches isn't divisible by gradient_accumulation_steps) so a
            # partial accumulation does not bleed into the next epoch.
            if step_in_accumulation > 0:
                self.optimizer_step()
                step_in_accumulation = 0
            # End-of-epoch evaluation
            if eval_dataloader:
                eval_loss = self.evaluate(eval_dataloader)
                logger.info(f"\nEpoch {epoch+1} Eval Loss: {eval_loss:.4f}")
        logger.info("\n" + "="*80)
        logger.info("Post-Training Complete!")
        logger.info(f"  Best Eval Loss: {self.best_eval_loss:.4f}")
        logger.info("="*80 + "\n")
        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")

    def save_checkpoint(self, path: Path, is_best: bool = False):
        """Serialize model/optimizer/scaler state and training progress to path."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'global_step': self.global_step,
            'best_eval_loss': self.best_eval_loss,
            'timestamp': datetime.now().isoformat()
        }
        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}" + (" (BEST)" if is_best else ""))

    def load_checkpoint(self, path: str):
        """Restore model/optimizer/scaler state and training progress from path."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.use_amp and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        self.global_step = checkpoint['global_step']
        self.best_eval_loss = checkpoint['best_eval_loss']
        logger.info(f"Checkpoint loaded from {path}")
def main():
    """Entry point: Phase 1 SFT post-training, then optional Phase 2 RLHF (GRPO)."""
    # Configuration
    config = {
        # Model configuration
        'model_dim': 1536,
        'vocab_size': 151665,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,
        'dropout': 0.0,
        'use_moe': False,
        # Training configuration
        'batch_size': 2,
        'gradient_accumulation_steps': 8,
        'learning_rate': 1e-5,
        'weight_decay': 0.01,
        'num_epochs': 3,
        'max_grad_norm': 1.0,
        # Data configuration
        'data_mix': 'simple_instruct',
        'max_samples_train': 20000,
        'max_samples_eval': 1000,
        'max_length': 512,
        'num_workers': 4,
        # RLHF configuration
        'do_rlhf': False,
        'preference_dataset': 'hh_rlhf',
        'grpo_iterations': 3,
        'grpo_kl_coef': 0.04,
        'grpo_group_size': 4,
        # Paths
        'pretrain_checkpoint': '/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt',
        'checkpoint_dir': 'checkpoints/posttrain',
        'log_interval': 50,
        'eval_interval': 500,
        'save_interval': 1000,
    }
    logger.info("Configuration:")
    logger.info(json.dumps(config, indent=2))
    # Initialize tokenizer
    logger.info("\nInitializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    config['vocab_size'] = len(tokenizer)
    # Initialize or load the model
    logger.info("\nInitializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=False,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )
    if config['pretrain_checkpoint']:
        logger.info(f"Loading pretrain checkpoint: {config['pretrain_checkpoint']}")
        # map_location='cpu' so a CUDA-saved checkpoint also loads on a
        # CPU-only machine; PostTrainer moves the model to the right device.
        checkpoint = torch.load(config['pretrain_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
    logger.info("\n" + "="*80)
    logger.info("PHASE 1: Supervised Fine-Tuning")
    logger.info("="*80)
    # Create data loaders
    train_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_train'],
        split='train',
        shuffle=True
    )
    eval_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'] * 2,
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_eval'],
        split='train',  # the tail of the train split is used for validation
        shuffle=False
    )
    # Create the trainer
    trainer = PostTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        num_epochs=config['num_epochs'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        eval_interval=config['eval_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir']
    )
    trainer.train(train_dataloader, eval_dataloader)
    if config['do_rlhf']:
        logger.info("\n" + "="*80)
        logger.info("PHASE 2: RLHF with GRPO")
        logger.info("="*80)
        try:
            # Train the reward model on preference pairs
            logger.info("\nTraining Reward Model...")
            reward_base_model = copy.deepcopy(model)
            reward_model = RewardModel(reward_base_model, use_value_head=True)
            preference_dataloader = create_preference_dataloader(
                dataset_name=config['preference_dataset'],
                tokenizer=tokenizer,
                batch_size=config['batch_size'],
                num_workers=config['num_workers'],
                max_samples=5000,
                split='train'
            )
            reward_trainer = RewardModelTrainer(
                reward_model=reward_model,
                learning_rate=1e-5
            )
            reward_trainer.train(preference_dataloader, num_epochs=1)
            # GRPO training (frozen reference model for the KL penalty)
            logger.info("\nStarting GRPO Training...")
            ref_model = copy.deepcopy(model)
            ref_model.eval()
            grpo_trainer = GRPOTrainer(
                actor_model=model,
                reward_model=reward_model,
                ref_model=ref_model,
                tokenizer=tokenizer,
                learning_rate=1e-6,
                kl_coef=config['grpo_kl_coef'],
                group_size=config['grpo_group_size'],
                update_batch_size=2,
                use_amp=True
            )
            # Build a prompt pool for GRPO rollouts
            prompt_dataloader = create_posttrain_dataloader(
                mix_name=config['data_mix'],
                tokenizer=tokenizer,
                batch_size=4,
                num_workers=2,
                max_samples=1000,
                split='train'
            )
            # Extract up to 200 instruction tensors as prompts
            prompts = []
            for batch in prompt_dataloader:
                if batch and batch.get('instruction') is not None:
                    prompts.append(batch['instruction'])
                if len(prompts) >= 200:
                    break
            if prompts:
                prompt_tensor = torch.cat(prompts[:200], dim=0)
                from torch.utils.data import TensorDataset, DataLoader
                prompt_loader = DataLoader(
                    TensorDataset(prompt_tensor),
                    batch_size=4
                )
                grpo_trainer.train(
                    prompt_loader,
                    num_iterations=config['grpo_iterations'],
                    max_gen_len=50,
                    save_path=config['checkpoint_dir'] + "/grpo"
                )
        except Exception as e:
            # RLHF is best-effort: log the failure but keep the SFT result.
            logger.error(f"Error in RLHF: {e}")
            import traceback
            traceback.print_exc()
    logger.info("\n" + "="*80)
    logger.info("All Training Complete!")
    logger.info("="*80)


if __name__ == "__main__":
    main()