"""Pre-training entry point for ``MultiModalDenseTransformer``.

Defines :class:`PreTrainer`, a single-device training loop with gradient
accumulation, mixed precision (when CUDA is available), a manual
warmup + cosine learning-rate schedule and checkpointing, plus a ``main``
function that wires tokenizer, model and dataloader together.
"""
import json
import logging
import math
import os
from datetime import datetime
from pathlib import Path

# The mirror endpoint has to be exported BEFORE transformers / huggingface_hub
# are imported: huggingface_hub resolves HF_ENDPOINT when it is first
# imported, so assigning it afterwards has no effect on downloads.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import torch  # noqa: E402
import torch.nn.functional as F  # noqa: E402
from tqdm import tqdm  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402

from model import MultiModalDenseTransformer  # noqa: E402
from data_loader import create_pretrain_dataloader  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class PreTrainer:
    """Autoregressive pre-training loop with gradient accumulation.

    One *step* (``global_step``) corresponds to one optimizer update, i.e.
    ``gradient_accumulation_steps`` calls to :meth:`train_step` followed by
    one call to :meth:`optimizer_step`.
    """

    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 3e-4,
        weight_decay: float = 0.1,
        warmup_steps: int = 1000,
        max_steps: int = 100000,
        gradient_accumulation_steps: int = 16,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/pretrain",
        loss_log_file: str = "checkpoints/pretrain/train_loss.log"
    ):
        """Build the optimizer, AMP scaler and bookkeeping state.

        Args:
            model: Model to train; moved to CUDA when available, else CPU.
            tokenizer: Tokenizer kept on the trainer for reference.
            learning_rate: Peak learning rate reached at the end of warmup.
            weight_decay: AdamW weight decay.
            warmup_steps: Number of optimizer steps of linear warmup.
            max_steps: Total number of optimizer steps to run.
            gradient_accumulation_steps: Micro-batches per optimizer step.
            max_grad_norm: Gradient-norm clipping threshold.
            log_interval: Log every this many optimizer steps.
            save_interval: Save a checkpoint every this many optimizer steps.
            checkpoint_dir: Directory for checkpoints (created if missing).
            loss_log_file: Text file that receives the loss log lines.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Learning-rate schedule state (computed manually in _get_lr).
        self.warmup_steps = warmup_steps
        self.max_lr = learning_rate
        self.min_lr = learning_rate * 0.1
        self.current_step = 0

        # Mixed precision is only enabled when CUDA is available.
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        # Training hyper-parameters.
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.max_steps = max_steps
        self.log_interval = log_interval
        self.save_interval = save_interval

        # Checkpoint management.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Loss log.
        self.loss_log_file = Path(loss_log_file)
        self.loss_log_file.parent.mkdir(parents=True, exist_ok=True)

        # Training state.
        self.global_step = 0
        self.tokens_seen = 0
        self.running_loss = 0.0
        # Number of optimizer steps accumulated into running_loss since the
        # last log line; lets the logged average stay correct after resuming
        # from a step that is not a multiple of log_interval.
        self._steps_since_log = 0
        self.best_loss = float('inf')

        # Apply the schedule right away so the very first update uses the
        # warmup learning rate instead of the peak rate passed to AdamW.
        self._set_lr(self._get_lr())

        logger.info("PreTrainer initialized:")
        logger.info(f" Device: {self.device}")
        logger.info(f" Learning Rate: {learning_rate}")
        logger.info(f" Max Steps: {max_steps}")
        logger.info(f" Gradient Accumulation: {gradient_accumulation_steps}")
        # NOTE: the trainer does not know the per-device batch size, so this
        # line reports the accumulation factor only.
        logger.info(f" Effective Batch Size: {gradient_accumulation_steps}")
        logger.info(f" Mixed Precision: {self.use_amp}")

    def _get_lr(self) -> float:
        """Return the learning rate for ``self.current_step``.

        Linear warmup from 0 to ``max_lr`` over ``warmup_steps``, followed by
        cosine decay from ``max_lr`` to ``min_lr`` that finishes at
        ``max_steps``. Past ``max_steps`` (or when there is no decay phase
        because ``max_steps <= warmup_steps``) the rate stays at ``min_lr``.
        """
        if self.current_step < self.warmup_steps:
            # Linear warmup.
            return self.max_lr * (self.current_step / self.warmup_steps)

        decay_steps = self.max_steps - self.warmup_steps
        if decay_steps <= 0:
            # No room for a decay phase; treat the schedule as finished
            # rather than dividing by zero.
            progress = 1.0
        else:
            # Clamp so the cosine does not climb back up after max_steps.
            progress = min(1.0, (self.current_step - self.warmup_steps) / decay_steps)

        # Cosine decay, returned as a plain Python float.
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return self.min_lr + (self.max_lr - self.min_lr) * cosine

    def _set_lr(self, lr: float):
        """Write ``lr`` into every optimizer parameter group."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def train_step(self, batch: dict) -> dict:
        """Run forward/backward on one micro-batch and accumulate gradients.

        Args:
            batch: Mapping with ``'input_ids'`` and ``'attention_mask'``
                tensors (indexed here as 2-D: batch x sequence).

        Returns:
            Dict with ``'loss'`` (the unscaled, masked mean next-token loss of
            this micro-batch) and ``'lr'`` (current optimizer learning rate).
        """
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)

        # Position ids count only attended tokens: 0, 1, 2, ... over the
        # non-masked positions of each row, and 0 at masked positions.
        token_mask = attention_mask.bool().long()
        position_ids = ((torch.cumsum(token_mask, dim=1) - 1) * token_mask).to(input_ids.dtype)

        # Model input in the segment format the model is called with.
        input_data = {
            'segments': [{
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            }]
        }

        # Forward pass and loss.
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            outputs = self.model(
                input_data,
                attention_mask=attention_mask,
                position_ids=position_ids
            )
            logits = outputs['logits']

            # Standard autoregressive objective: position t predicts token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            shift_attention_mask = attention_mask[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                reduction='none'
            )

            # Average over attended target positions only; the epsilon guards
            # against a fully masked batch.
            loss = (loss * shift_attention_mask.view(-1)).sum() / (shift_attention_mask.sum() + 1e-8)

        # Scale down so the accumulated gradient is the mean over micro-batches.
        loss_for_backward = loss / self.gradient_accumulation_steps
        self.scaler.scale(loss_for_backward).backward()

        self.tokens_seen += attention_mask.sum().item()

        return {
            'loss': loss.item(),  # real loss, not divided by the accumulation factor
            'lr': self.optimizer.param_groups[0]['lr']
        }

    def optimizer_step(self):
        """Clip gradients, apply one optimizer update and advance the schedule.

        Returns:
            The total gradient norm measured before clipping, as a float.
        """
        # Unscale first so clipping operates on true gradient magnitudes.
        self.scaler.unscale_(self.optimizer)

        # Gradient clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        # Parameter update.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)

        # Advance the schedule and set the rate for the next update.
        self.current_step += 1
        self.global_step += 1
        lr = self._get_lr()
        self._set_lr(lr)

        return grad_norm.item()

    def _write_loss_to_txt(self, step, avg_loss, lr, tokens_seen):
        """Append one timestamped line to the loss log file."""
        log_content = (
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
            f"Step: {step}/{self.max_steps}, "
            f"Average Loss: {avg_loss:.4f}, "
            f"Learning Rate: {lr:.2e}, "
            f"Tokens Seen: {tokens_seen/1e9:.2f}B\n"
        )
        with open(self.loss_log_file, 'a', encoding='utf-8') as f:
            f.write(log_content)

    def train(self, dataloader, resume_from=None):
        """Run the training loop until ``max_steps`` or the data runs out.

        Args:
            dataloader: Iterable yielding batches accepted by
                :meth:`train_step`. ``None`` or empty batches are skipped.
            resume_from: Optional checkpoint path. When given, the checkpoint
                is loaded and ``global_step * gradient_accumulation_steps``
                batches are consumed from ``dataloader`` first so training
                continues where the data stream left off.

        A ``KeyboardInterrupt`` during training saves an
        ``interrupted_step_<N>.pt`` checkpoint. A ``final_model.pt``
        checkpoint is written when the loop ends.
        """
        logger.info("\n" + "="*80)
        logger.info("Starting Pre-Training (Fixed Version)")
        logger.info("="*80 + "\n")

        # Resume training.
        if resume_from:
            self.load_checkpoint(resume_from)

        # Initialise the loss log on first use.
        if not self.loss_log_file.exists():
            with open(self.loss_log_file, 'w', encoding='utf-8') as f:
                f.write(" Fixed Training Log (Real Loss Values)\n")
                f.write("="*80 + "\n")

        self.model.train()
        # Drop gradients left over from a previously interrupted accumulation.
        self.optimizer.zero_grad(set_to_none=True)

        step_in_accumulation = 0
        accumulated_loss = 0.0

        batches_to_skip = self.global_step * self.gradient_accumulation_steps
        logger.info(f"Current Global Step: {self.global_step}")
        if batches_to_skip > 0:
            logger.info(f" Resuming: Need to skip {batches_to_skip} batches to restore data state...")
            logger.info("This might take a while depending on network/disk speed...")

        # Create the iterator.
        data_iterator = iter(dataloader)

        if batches_to_skip > 0:
            with tqdm(total=batches_to_skip, desc="Skipping trained batches", unit="batch") as skip_pbar:
                for _ in range(batches_to_skip):
                    try:
                        # Only pull the data: no forward pass, no gradients.
                        next(data_iterator)
                    except StopIteration:
                        logger.error("Dataset exhausted during skipping! Check your dataset size or max_steps.")
                        return
                    skip_pbar.update(1)
            logger.info(" Data fast-forward complete. Resuming training...")

        # Created after the fast-forward so an early return cannot leak it.
        progress_bar = tqdm(total=self.max_steps, initial=self.global_step)

        try:
            # Checking up front also covers resuming a run that is already done.
            while self.global_step < self.max_steps:
                try:
                    batch = next(data_iterator)
                except StopIteration:
                    break

                if batch is None or batch['input_ids'].size(0) == 0:
                    continue

                stats = self.train_step(batch)
                step_in_accumulation += 1
                accumulated_loss += stats['loss']

                if step_in_accumulation < self.gradient_accumulation_steps:
                    continue

                # A full accumulation window is ready: update the parameters.
                avg_step_loss = accumulated_loss / self.gradient_accumulation_steps
                grad_norm = self.optimizer_step()
                stats['grad_norm'] = grad_norm
                stats['loss'] = avg_step_loss
                self.running_loss += avg_step_loss
                self._steps_since_log += 1
                step_in_accumulation = 0
                accumulated_loss = 0.0

                progress_bar.update(1)
                progress_bar.set_postfix({
                    'loss': f"{stats['loss']:.4f}",
                    'lr': f"{stats['lr']:.2e}",
                    'tokens': f"{self.tokens_seen/1e9:.2f}B",
                    'grad': f"{grad_norm:.2f}"
                })

                # Periodic logging.
                if self.global_step % self.log_interval == 0:
                    # Divide by the steps actually accumulated, which can be
                    # fewer than log_interval right after a resume.
                    avg_loss = self.running_loss / max(1, self._steps_since_log)
                    logger.info(
                        f"Step {self.global_step}/{self.max_steps} | "
                        f"Loss: {avg_loss:.4f} | "
                        f"LR: {stats['lr']:.2e} | "
                        f"GradNorm: {grad_norm:.2f} | "
                        f"Tokens: {self.tokens_seen/1e9:.2f}B"
                    )
                    if avg_loss < self.best_loss:
                        self.best_loss = avg_loss
                        logger.info(f" New best loss: {self.best_loss:.4f}")
                    self._write_loss_to_txt(
                        step=self.global_step,
                        avg_loss=avg_loss,
                        lr=stats['lr'],
                        tokens_seen=self.tokens_seen
                    )
                    self.running_loss = 0.0
                    self._steps_since_log = 0

                # Periodic checkpoint.
                if self.global_step % self.save_interval == 0:
                    self.save_checkpoint(
                        self.checkpoint_dir / f"step_{self.global_step}.pt"
                    )
        except KeyboardInterrupt:
            self.save_checkpoint(
                self.checkpoint_dir / f"interrupted_step_{self.global_step}.pt"
            )
        finally:
            progress_bar.close()

        logger.info("\n" + "="*80)
        logger.info("Pre-Training Complete!")
        logger.info(f" Total Steps: {self.global_step}")
        logger.info(f" Total Tokens: {self.tokens_seen/1e9:.2f}B")
        logger.info(f" Best Loss: {self.best_loss:.4f}")
        logger.info("="*80 + "\n")

        # Save the final model.
        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")

    def save_checkpoint(self, path: Path):
        """Save model, optimizer, scaler and progress counters to ``path``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'global_step': self.global_step,
            'current_step': self.current_step,
            'tokens_seen': self.tokens_seen,
            'best_loss': self.best_loss,
            'timestamp': datetime.now().isoformat()
        }
        torch.save(checkpoint, path)
        logger.info(f" Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """Restore model, optimizer, scaler and progress counters from ``path``.

        ``current_step`` falls back to ``global_step`` and ``best_loss`` to
        infinity when the checkpoint does not contain them.
        """
        checkpoint = torch.load(path, map_location=self.device, weights_only=True)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.use_amp and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.global_step = checkpoint['global_step']
        self.current_step = checkpoint.get('current_step', self.global_step)
        self.tokens_seen = checkpoint['tokens_seen']
        self.best_loss = checkpoint.get('best_loss', float('inf'))

        logger.info(f" Checkpoint loaded from {path}")
        logger.info(f" Resuming from step {self.global_step}")
        logger.info(f" Tokens seen: {self.tokens_seen/1e9:.2f}B")


def main():
    """Build tokenizer, model, dataloader and trainer, then start training."""
    config = {
        # Model configuration
        'model_dim': 1536,
        'vocab_size': 151665,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,
        'dropout': 0.1,
        'use_moe': False,
        # Optimisation
        'batch_size': 4,
        'gradient_accumulation_steps': 8,
        'learning_rate': 3e-4,
        'weight_decay': 0.1,
        'warmup_steps': 500,
        'max_steps': 10000,
        'max_grad_norm': 1.0,
        # Data configuration
        'data_mix': 'text_only',
        'max_length': 512,
        'num_workers': 2,
        # Logging and saving
        'log_interval': 10,
        'save_interval': 500,
        'checkpoint_dir': 'checkpoints/pretrain_fixed',
        'loss_log_file': 'checkpoints/pretrain_fixed/train_loss.log',
        # Checkpoint to resume from; set to None to start from scratch.
        'resume_from': '/root/step_6500.pt'
    }

    logger.info("="*80)
    logger.info(json.dumps(config, indent=2))
    logger.info("="*80 + "\n")

    # Initialise the tokenizer.
    logger.info("Initializing tokenizer...")
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # model repository; confirm it is actually required for this tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # The tokenizer's real size overrides the placeholder in the config.
    config['vocab_size'] = len(tokenizer)
    logger.info(f"Vocab size: {config['vocab_size']}\n")

    # Initialise the model.
    logger.info("Initializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=True,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )

    # Create the dataloader.
    logger.info(f"\nCreating dataloader (mix: {config['data_mix']})...")
    dataloader = create_pretrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length']
    )

    # Create the trainer.
    trainer = PreTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        warmup_steps=config['warmup_steps'],
        max_steps=config['max_steps'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir'],
        loss_log_file=config['loss_log_file']
    )

    resume_from = config['resume_from']
    if resume_from:
        logger.info(f"\n Resuming training from {resume_from}...\n")
    else:
        logger.info("\n Starting fresh training with fixes...\n")
    trainer.train(dataloader, resume_from=resume_from)


if __name__ == "__main__":
    main()