import os

# HF_ENDPOINT must be set before transformers/huggingface_hub are imported;
# the mirror endpoint is read at import time, so setting it later has no effect.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from pathlib import Path
import logging
from tqdm import tqdm
import json
from datetime import datetime
import copy
from typing import Optional

from model import MultiModalDenseTransformer
from data_loader import (
    create_posttrain_dataloader,
    create_preference_dataloader
)
from data_config import POSTTRAIN_MIX
from reward_model import RewardModel, RewardModelTrainer
from grpo import GRPOTrainer

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class PostTrainer:
    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 1e-5,
        weight_decay: float = 0.01,
        num_epochs: int = 3,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        eval_interval: int = 500,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/posttrain"
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Mixed precision
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Training parameters
        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval

        # Checkpoint management
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Training state
        self.global_step = 0
        self.best_eval_loss = float('inf')

        logger.info("PostTrainer initialized:")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Num Epochs: {num_epochs}")
        logger.info(f"  Gradient Accumulation: {gradient_accumulation_steps}")

    def train_step(self, batch: dict) -> dict:
        """Single training step."""
        instruction_ids = batch['instruction'].to(self.device)
        response_ids = batch['response'].to(self.device)
        instruction_mask = batch['instruction_mask'].to(self.device)
        response_mask = batch['response_mask'].to(self.device)

        input_ids = torch.cat([instruction_ids, response_ids], dim=1)
        attention_mask = torch.cat([instruction_mask, response_mask], dim=1)

        batch_size, seq_len = input_ids.shape

        # Padding-aware position ids: each real token gets its rank among
        # non-pad tokens; pad positions are zeroed out.
        position_ids = torch.zeros_like(input_ids)
        for i in range(batch_size):
            non_pad_mask = attention_mask[i].bool()
            if non_pad_mask.any():
                positions = torch.cumsum(non_pad_mask.long(), dim=0) - 1
                position_ids[i] = positions * non_pad_mask.long()

        labels = input_ids.clone()
        # Mask out the instruction portion so loss is only taken on the response
        instr_len = instruction_ids.shape[1]
        labels[:, :instr_len] = -100
        labels[attention_mask == 0] = -100

        # Prepare model input
        input_data = {
            'segments': [{
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            }]
        }

        # Forward pass
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            outputs = self.model(
                input_data,
                attention_mask=attention_mask,
                position_ids=position_ids
            )
            logits = outputs['logits']

            # Compute causal LM loss with a one-token shift
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )

        raw_loss = loss.item()
        loss = loss / self.gradient_accumulation_steps

        # Backward pass
        self.scaler.scale(loss).backward()

        return {'loss': raw_loss}
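    # Worked example of the masking above (illustrative token names, not real
    # ids): with instruction [I1, I2, PAD] and response [R1, R2], input_ids is
    # [I1, I2, PAD, R1, R2], position_ids become [0, 1, 0, 2, 3], and labels
    # are [-100, -100, -100, R1, R2]. After the one-token shift, the logits at
    # position t are scored against the token at t+1, so cross-entropy is
    # computed only where the shifted label is a response token.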
    def optimizer_step(self):
        """Optimizer step with gradient clipping."""
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.global_step += 1
        return grad_norm.item()

    @torch.no_grad()
    def evaluate(self, dataloader, max_batches: int = 50) -> float:
        """Evaluation."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for i, batch in enumerate(dataloader):
            if i >= max_batches:
                break
            if batch is None:
                continue

            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            input_ids = torch.cat([instruction_ids, response_ids], dim=1)

            labels = input_ids.clone()
            labels[:, :instruction_ids.shape[1]] = -100
            # Caveat: when pad_token == eos_token (as set in main()), this also
            # masks EOS targets, so eval loss ignores end-of-sequence predictions.
            labels[input_ids == self.tokenizer.pad_token_id] = -100

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': input_ids,
                    'modality_id': 0
                }]
            }

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.model(input_data)
                logits = outputs['logits']

                shift_logits = logits[:, :-1, :].contiguous()
                shift_labels = labels[:, 1:].contiguous()
                loss = F.cross_entropy(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                    ignore_index=-100
                )

            total_loss += loss.item()
            num_batches += 1

        self.model.train()
        return total_loss / max(num_batches, 1)

    def train(
        self,
        train_dataloader,
        eval_dataloader=None,
        resume_from: Optional[str] = None
    ):
        """Training loop."""
        logger.info("\n" + "=" * 80)
        logger.info("Starting Post-Training (SFT)")
        logger.info("=" * 80 + "\n")

        if resume_from:
            self.load_checkpoint(resume_from)

        self.model.train()

        for epoch in range(self.num_epochs):
            logger.info(f"\nEpoch {epoch + 1}/{self.num_epochs}")
            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")

            running_loss = 0.0
            step_in_accumulation = 0

            for batch_idx, batch in enumerate(progress_bar):
                if batch is None:
                    continue

                # Training step
                stats = self.train_step(batch)
                running_loss += stats['loss']
                step_in_accumulation += 1

                # Update progress bar
                progress_bar.set_postfix({'loss': f"{stats['loss']:.4f}"})

                # Optimizer update. Logging/eval/save are keyed to optimizer
                # steps and must run only when global_step actually advances;
                # otherwise they fire once per batch while global_step sits at
                # the same multiple for a whole accumulation window.
                if step_in_accumulation == self.gradient_accumulation_steps:
                    grad_norm = self.optimizer_step()
                    step_in_accumulation = 0

                    # Logging: running_loss accumulates per batch, so average
                    # over log_interval steps * accumulation batches per step
                    if self.global_step % self.log_interval == 0:
                        avg_loss = running_loss / (
                            self.log_interval * self.gradient_accumulation_steps
                        )
                        logger.info(
                            f"Step {self.global_step} | "
                            f"Epoch {epoch + 1} | "
                            f"Loss: {avg_loss:.4f}"
                        )
                        running_loss = 0.0

                    # Evaluation
                    if eval_dataloader and self.global_step % self.eval_interval == 0:
                        eval_loss = self.evaluate(eval_dataloader)
                        logger.info(f"Eval Loss: {eval_loss:.4f}")
                        if eval_loss < self.best_eval_loss:
                            self.best_eval_loss = eval_loss
                            self.save_checkpoint(
                                self.checkpoint_dir / "best_model.pt",
                                is_best=True
                            )

                    # Periodic checkpoint
                    if self.global_step % self.save_interval == 0:
                        self.save_checkpoint(
                            self.checkpoint_dir / f"step_{self.global_step}.pt"
                        )

            # End-of-epoch evaluation
            if eval_dataloader:
                eval_loss = self.evaluate(eval_dataloader)
                logger.info(f"\nEpoch {epoch + 1} Eval Loss: {eval_loss:.4f}")

        logger.info("\n" + "=" * 80)
        logger.info("Post-Training Complete!")
        logger.info(f"  Best Eval Loss: {self.best_eval_loss:.4f}")
        logger.info("=" * 80 + "\n")

        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")

    def save_checkpoint(self, path: Path, is_best: bool = False):
        """Save a checkpoint."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'global_step': self.global_step,
            'best_eval_loss': self.best_eval_loss,
            'timestamp': datetime.now().isoformat()
        }
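        # Note (assumption about deployment, not exercised by this script): the
        # model is saved unwrapped. If it were wrapped in DataParallel/DDP, the
        # state_dict keys would gain a "module." prefix, and load_checkpoint()
        # below would need to strip it before calling load_state_dict().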
        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}" + (" (BEST)" if is_best else ""))

    def load_checkpoint(self, path: str):
        """Load a checkpoint."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.use_amp and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        self.global_step = checkpoint['global_step']
        self.best_eval_loss = checkpoint['best_eval_loss']
        logger.info(f"Checkpoint loaded from {path}")


def main():
    """Entry point."""
    # Configuration
    config = {
        # Model config
        'model_dim': 1536,
        'vocab_size': 151665,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,
        'dropout': 0.0,
        'use_moe': False,

        # Training config
        'batch_size': 2,
        'gradient_accumulation_steps': 8,
        'learning_rate': 1e-5,
        'weight_decay': 0.01,
        'num_epochs': 3,
        'max_grad_norm': 1.0,

        # Data config
        'data_mix': 'simple_instruct',
        'max_samples_train': 20000,
        'max_samples_eval': 1000,
        'max_length': 512,
        'num_workers': 4,

        # RLHF config
        'do_rlhf': False,
        'preference_dataset': 'hh_rlhf',
        'grpo_iterations': 3,
        'grpo_kl_coef': 0.04,
        'grpo_group_size': 4,

        # Paths
        'pretrain_checkpoint': '/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt',
        'checkpoint_dir': 'checkpoints/posttrain',
        'log_interval': 50,
        'eval_interval': 500,
        'save_interval': 1000,
    }

    logger.info("Configuration:")
    logger.info(json.dumps(config, indent=2))

    # Initialize tokenizer
    logger.info("\nInitializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    config['vocab_size'] = len(tokenizer)

    # Initialize or load model
    logger.info("\nInitializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=False,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )

    if config['pretrain_checkpoint']:
        logger.info(f"Loading pretrain checkpoint: {config['pretrain_checkpoint']}")
        # Load to CPU first; PostTrainer moves the model to the right device
        checkpoint = torch.load(config['pretrain_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    logger.info("\n" + "=" * 80)
    logger.info("PHASE 1: Supervised Fine-Tuning")
    logger.info("=" * 80)

    # Create dataloaders
    train_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_train'],
        split='train',
        shuffle=True
    )
    eval_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'] * 2,
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_eval'],
        split='train',  # use the tail of the train split as validation
        shuffle=False
    )

    # Create trainer
    trainer = PostTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        num_epochs=config['num_epochs'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        eval_interval=config['eval_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir']
    )
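    # With the config above, each optimizer step consumes
    # batch_size * gradient_accumulation_steps = 2 * 8 = 16 sequences,
    # so log_interval=50 optimizer steps corresponds to 400 data batches.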
    trainer.train(train_dataloader, eval_dataloader)

    if config['do_rlhf']:
        logger.info("\n" + "=" * 80)
        logger.info("PHASE 2: RLHF with GRPO")
        logger.info("=" * 80)

        try:
            # Train the reward model
            logger.info("\nTraining Reward Model...")
            reward_base_model = copy.deepcopy(model)
            reward_model = RewardModel(reward_base_model, use_value_head=True)

            preference_dataloader = create_preference_dataloader(
                dataset_name=config['preference_dataset'],
                tokenizer=tokenizer,
                batch_size=config['batch_size'],
                num_workers=config['num_workers'],
                max_samples=5000,
                split='train'
            )

            reward_trainer = RewardModelTrainer(
                reward_model=reward_model,
                learning_rate=1e-5
            )
            reward_trainer.train(preference_dataloader, num_epochs=1)

            # GRPO training: the frozen reference model anchors the KL penalty
            logger.info("\nStarting GRPO Training...")
            ref_model = copy.deepcopy(model)
            ref_model.eval()

            grpo_trainer = GRPOTrainer(
                actor_model=model,
                reward_model=reward_model,
                ref_model=ref_model,
                tokenizer=tokenizer,
                learning_rate=1e-6,
                kl_coef=config['grpo_kl_coef'],
                group_size=config['grpo_group_size'],
                update_batch_size=2,
                use_amp=True
            )

            # Prepare prompts
            prompt_dataloader = create_posttrain_dataloader(
                mix_name=config['data_mix'],
                tokenizer=tokenizer,
                batch_size=4,
                num_workers=2,
                max_samples=1000,
                split='train'
            )

            # Extract up to 200 prompt batches (~800 prompts at batch_size=4)
            prompts = []
            for batch in prompt_dataloader:
                if batch and batch.get('instruction') is not None:
                    prompts.append(batch['instruction'])
                if len(prompts) >= 200:
                    break

            if prompts:
                # Assumes every batch is padded to the same length;
                # otherwise torch.cat along dim=0 would fail
                prompt_tensor = torch.cat(prompts[:200], dim=0)
                from torch.utils.data import TensorDataset, DataLoader
                prompt_loader = DataLoader(
                    TensorDataset(prompt_tensor),
                    batch_size=4
                )
                grpo_trainer.train(
                    prompt_loader,
                    num_iterations=config['grpo_iterations'],
                    max_gen_len=50,
                    save_path=config['checkpoint_dir'] + "/grpo"
                )
        except Exception as e:
            logger.error(f"Error in RLHF: {e}")
            import traceback
            traceback.print_exc()

    logger.info("\n" + "=" * 80)
    logger.info("All Training Complete!")
    logger.info("=" * 80)


if __name__ == "__main__":
    main()
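# Usage sketch (the script name and checkpoint path below are illustrative,
# not produced by this exact run):
#
#   python posttrain.py        # SFT only, since config['do_rlhf'] is False
#
# Set config['do_rlhf'] = True to also run the reward-model + GRPO phase.
# To resume SFT from a saved step, pass resume_from to PostTrainer.train:
#
#   trainer.train(train_dataloader, eval_dataloader,
#                 resume_from="checkpoints/posttrain/step_1000.pt")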