|
|
import os |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoTokenizer |
|
|
from pathlib import Path |
|
|
import logging |
|
|
from tqdm import tqdm |
|
|
import json |
|
|
from datetime import datetime |
|
|
import copy |
|
|
from model import MultiModalDenseTransformer |
|
|
|
|
|
from data_loader import ( |
|
|
create_posttrain_dataloader, |
|
|
create_preference_dataloader |
|
|
) |
|
|
from data_config import POSTTRAIN_MIX |
|
|
from reward_model import RewardModel, RewardModelTrainer |
|
|
from grpo import GRPOTrainer |
|
|
from typing import Optional |
|
|
|
|
|
# Module-level logging: timestamped, name- and level-tagged records at INFO.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Route Hugging Face Hub traffic through a mirror (useful on networks where
# huggingface.co is unreachable). Must be set before the first hub download.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
|
|
|
|
class PostTrainer:
    """Supervised fine-tuning (SFT) trainer for MultiModalDenseTransformer.

    Wraps AdamW optimization with mixed precision (CUDA AMP + GradScaler),
    gradient accumulation, periodic evaluation, and checkpointing.
    """

    def __init__(
        self,
        model: MultiModalDenseTransformer,
        tokenizer,
        learning_rate: float = 1e-5,
        weight_decay: float = 0.01,
        num_epochs: int = 3,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        log_interval: int = 10,
        eval_interval: int = 500,
        save_interval: int = 1000,
        checkpoint_dir: str = "checkpoints/posttrain"
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model.to(self.device)

        # betas=(0.9, 0.95) follows common LLM fine-tuning recipes.
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Mixed precision only on CUDA; GradScaler is a pass-through when disabled.
        self.use_amp = torch.cuda.is_available()
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        self.num_epochs = num_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_grad_norm = max_grad_norm
        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.save_interval = save_interval

        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # global_step counts optimizer updates, not micro-batches.
        self.global_step = 0
        self.best_eval_loss = float('inf')

        logger.info("PostTrainer initialized:")
        logger.info(f" Device: {self.device}")
        logger.info(f" Learning Rate: {learning_rate}")
        logger.info(f" Num Epochs: {num_epochs}")
        logger.info(f" Gradient Accumulation: {gradient_accumulation_steps}")

    def train_step(self, batch: dict) -> dict:
        """Run one forward/backward micro-step on an SFT batch.

        Args:
            batch: dict with 'instruction' / 'response' token-id tensors and
                matching 'instruction_mask' / 'response_mask' (1 = real token).

        Returns:
            dict with the unscaled 'loss' (float) of this micro-batch.
        """
        instruction_ids = batch['instruction'].to(self.device)
        response_ids = batch['response'].to(self.device)

        instruction_mask = batch['instruction_mask'].to(self.device)
        response_mask = batch['response_mask'].to(self.device)

        # The model sees instruction and response as one concatenated sequence.
        input_ids = torch.cat([instruction_ids, response_ids], dim=1)
        attention_mask = torch.cat([instruction_mask, response_mask], dim=1)

        # Position ids count only non-pad tokens; pad positions stay 0.
        # Vectorized replacement for the original per-sample Python loop:
        # (cumsum - 1) is the 0-based index among real tokens, and the final
        # multiply zeroes out pad positions (including the -1 at leading pads).
        mask_long = attention_mask.long()
        position_ids = (torch.cumsum(mask_long, dim=1) - 1) * mask_long

        labels = input_ids.clone()

        # Standard SFT masking: no loss on the prompt, no loss on padding.
        instr_len = instruction_ids.shape[1]
        labels[:, :instr_len] = -100
        labels[attention_mask == 0] = -100

        input_data = {
            'segments': [{
                'type': 'text',
                'data': input_ids,
                'modality_id': 0
            }]
        }

        with torch.amp.autocast('cuda', enabled=self.use_amp):
            outputs = self.model(input_data, attention_mask=attention_mask,
                                 position_ids=position_ids)
            logits = outputs['logits']

            # Shift for next-token prediction.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )

        raw_loss = loss.item()
        # Scale down so the accumulated gradient averages over micro-batches.
        loss = loss / self.gradient_accumulation_steps

        self.scaler.scale(loss).backward()

        return {
            'loss': raw_loss
        }

    def optimizer_step(self):
        """Unscale, clip, and apply accumulated gradients.

        Returns:
            The pre-clip gradient norm as a float.
        """
        # Unscale before clipping so the clip threshold applies to true grads.
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.global_step += 1
        return grad_norm.item()

    @torch.no_grad()
    def evaluate(self, dataloader, max_batches: int = 50) -> float:
        """Average cross-entropy loss over up to max_batches batches.

        Builds the attention mask and position ids exactly as train_step does,
        so eval loss is directly comparable to training loss (previously the
        mask was omitted and labels were masked via pad_token_id, which also
        masked the real EOS token whenever pad_token == eos_token).
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for i, batch in enumerate(dataloader):
            if i >= max_batches:
                break

            if batch is None:
                continue

            instruction_ids = batch['instruction'].to(self.device)
            response_ids = batch['response'].to(self.device)
            instruction_mask = batch['instruction_mask'].to(self.device)
            response_mask = batch['response_mask'].to(self.device)

            input_ids = torch.cat([instruction_ids, response_ids], dim=1)
            attention_mask = torch.cat([instruction_mask, response_mask], dim=1)

            mask_long = attention_mask.long()
            position_ids = (torch.cumsum(mask_long, dim=1) - 1) * mask_long

            labels = input_ids.clone()
            labels[:, :instruction_ids.shape[1]] = -100
            labels[attention_mask == 0] = -100

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': input_ids,
                    'modality_id': 0
                }]
            }

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.model(input_data, attention_mask=attention_mask,
                                     position_ids=position_ids)
                logits = outputs['logits']

            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )

            total_loss += loss.item()
            num_batches += 1

        self.model.train()
        return total_loss / max(num_batches, 1)

    def train(
        self,
        train_dataloader,
        eval_dataloader=None,
        resume_from: Optional[str] = None
    ):
        """Main SFT loop: accumulate -> optimizer step -> log/eval/save.

        Args:
            train_dataloader: yields SFT batches (may yield None; skipped).
            eval_dataloader: optional loader for periodic evaluation.
            resume_from: optional checkpoint path to resume training from.
        """
        logger.info("\n" + "="*80)
        logger.info("Starting Post-Training (SFT)")
        logger.info("="*80 + "\n")

        if resume_from:
            self.load_checkpoint(resume_from)

        self.model.train()

        for epoch in range(self.num_epochs):
            logger.info(f"\nEpoch {epoch+1}/{self.num_epochs}")

            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
            running_loss = 0.0
            running_count = 0
            step_in_accumulation = 0

            for batch in progress_bar:
                if batch is None:
                    continue

                stats = self.train_step(batch)
                running_loss += stats['loss']
                running_count += 1
                step_in_accumulation += 1

                progress_bar.set_postfix({'loss': f"{stats['loss']:.4f}"})

                if step_in_accumulation < self.gradient_accumulation_steps:
                    continue

                self.optimizer_step()
                step_in_accumulation = 0

                # Log/eval/save only right after an optimizer step: global_step
                # is constant between optimizer steps, so checking it on every
                # micro-batch (as before) re-fired each trigger repeatedly once
                # a multiple of the interval was reached.
                if self.global_step % self.log_interval == 0:
                    # Average over the micro-batches actually accumulated since
                    # the last log (dividing by log_interval overstated the loss
                    # by the gradient-accumulation factor).
                    avg_loss = running_loss / max(running_count, 1)
                    logger.info(
                        f"Step {self.global_step} | "
                        f"Epoch {epoch+1} | "
                        f"Loss: {avg_loss:.4f}"
                    )
                    running_loss = 0.0
                    running_count = 0

                if eval_dataloader and self.global_step % self.eval_interval == 0:
                    eval_loss = self.evaluate(eval_dataloader)
                    logger.info(f"Eval Loss: {eval_loss:.4f}")

                    if eval_loss < self.best_eval_loss:
                        self.best_eval_loss = eval_loss
                        self.save_checkpoint(
                            self.checkpoint_dir / "best_model.pt",
                            is_best=True
                        )

                if self.global_step % self.save_interval == 0:
                    self.save_checkpoint(
                        self.checkpoint_dir / f"step_{self.global_step}.pt"
                    )

            # Flush a partial accumulation window at epoch end so leftover
            # gradients are applied rather than leaking into the next epoch.
            if step_in_accumulation > 0:
                self.optimizer_step()

            if eval_dataloader:
                eval_loss = self.evaluate(eval_dataloader)
                logger.info(f"\nEpoch {epoch+1} Eval Loss: {eval_loss:.4f}")

        logger.info("\n" + "="*80)
        logger.info("Post-Training Complete!")
        logger.info(f" Best Eval Loss: {self.best_eval_loss:.4f}")
        logger.info("="*80 + "\n")

        self.save_checkpoint(self.checkpoint_dir / "final_model.pt")

    def save_checkpoint(self, path: Path, is_best: bool = False):
        """Serialize model/optimizer/scaler state plus progress counters."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
            'global_step': self.global_step,
            'best_eval_loss': self.best_eval_loss,
            'timestamp': datetime.now().isoformat()
        }

        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}" + (" (BEST)" if is_best else ""))

    def load_checkpoint(self, path: str):
        """Restore model/optimizer/scaler state and progress counters."""
        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Scaler state is None when the checkpoint was written without AMP.
        if self.use_amp and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.global_step = checkpoint['global_step']
        self.best_eval_loss = checkpoint['best_eval_loss']

        logger.info(f"Checkpoint loaded from {path}")
|
|
|
|
|
def main():
    """Entry point: build tokenizer, model, and dataloaders, run SFT
    (phase 1), then optionally reward-model training + GRPO (phase 2)."""

    config = {
        # --- model architecture ---
        'model_dim': 1536,
        'vocab_size': 151665,  # placeholder; overwritten by len(tokenizer) below
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'max_seq_len': 512,
        'dropout': 0.0,
        'use_moe': False,

        # --- optimization ---
        'batch_size': 2,
        'gradient_accumulation_steps': 8,
        'learning_rate': 1e-5,
        'weight_decay': 0.01,
        'num_epochs': 3,
        'max_grad_norm': 1.0,

        # --- data ---
        'data_mix': 'simple_instruct',
        'max_samples_train': 20000,
        'max_samples_eval': 1000,
        'max_length': 512,
        'num_workers': 4,

        # --- RLHF (phase 2) ---
        'do_rlhf': False,
        'preference_dataset': 'hh_rlhf',
        'grpo_iterations': 3,
        'grpo_kl_coef': 0.04,
        'grpo_group_size': 4,

        # --- checkpoints & logging ---
        'pretrain_checkpoint': '/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt',
        'checkpoint_dir': 'checkpoints/posttrain',
        'log_interval': 50,
        'eval_interval': 500,
        'save_interval': 1000,
    }

    logger.info("Configuration:")
    logger.info(json.dumps(config, indent=2))

    logger.info("\nInitializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct",
        use_fast=True,
        trust_remote_code=True
    )

    # Qwen tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Keep the embedding table in sync with the actual tokenizer vocab.
    config['vocab_size'] = len(tokenizer)

    logger.info("\nInitializing model...")
    model = MultiModalDenseTransformer(
        model_dim=config['model_dim'],
        vocab_size=config['vocab_size'],
        n_layers=config['n_layers'],
        n_heads=config['n_heads'],
        n_kv_heads=config['n_kv_heads'],
        max_seq_len=config['max_seq_len'],
        dropout=config['dropout'],
        use_moe=config['use_moe'],
        use_gradient_checkpointing=False,
        rope_scaling_type="yarn",
        use_multimodal_fusion=False,
        use_contrastive=False
    )

    if config['pretrain_checkpoint']:
        logger.info(f"Loading pretrain checkpoint: {config['pretrain_checkpoint']}")
        # map_location='cpu' so a GPU-saved checkpoint also loads on a
        # CPU-only host; PostTrainer moves the model to the right device.
        checkpoint = torch.load(config['pretrain_checkpoint'], map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    logger.info("\n" + "="*80)
    logger.info("PHASE 1: Supervised Fine-Tuning")
    logger.info("="*80)

    train_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_train'],
        split='train',
        shuffle=True
    )

    # NOTE(review): eval also reads the 'train' split (first N samples,
    # unshuffled) and may overlap training data — consider a held-out split.
    eval_dataloader = create_posttrain_dataloader(
        mix_name=config['data_mix'],
        tokenizer=tokenizer,
        batch_size=config['batch_size'] * 2,
        num_workers=config['num_workers'],
        max_length=config['max_length'],
        max_samples=config['max_samples_eval'],
        split='train',
        shuffle=False
    )

    trainer = PostTrainer(
        model=model,
        tokenizer=tokenizer,
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        num_epochs=config['num_epochs'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        max_grad_norm=config['max_grad_norm'],
        log_interval=config['log_interval'],
        eval_interval=config['eval_interval'],
        save_interval=config['save_interval'],
        checkpoint_dir=config['checkpoint_dir']
    )

    trainer.train(train_dataloader, eval_dataloader)

    if config['do_rlhf']:
        logger.info("\n" + "="*80)
        logger.info("PHASE 2: RLHF with GRPO")
        logger.info("="*80)

        # Phase 2 is best-effort: any failure is logged and the run still
        # finishes with the SFT checkpoints intact.
        try:
            logger.info("\nTraining Reward Model...")

            # The reward model starts from a copy of the SFT policy.
            reward_base_model = copy.deepcopy(model)
            reward_model = RewardModel(reward_base_model, use_value_head=True)

            preference_dataloader = create_preference_dataloader(
                dataset_name=config['preference_dataset'],
                tokenizer=tokenizer,
                batch_size=config['batch_size'],
                num_workers=config['num_workers'],
                max_samples=5000,
                split='train'
            )

            reward_trainer = RewardModelTrainer(
                reward_model=reward_model,
                learning_rate=1e-5
            )

            reward_trainer.train(preference_dataloader, num_epochs=1)

            logger.info("\nStarting GRPO Training...")

            # Frozen reference policy for the KL penalty.
            ref_model = copy.deepcopy(model)
            ref_model.eval()

            grpo_trainer = GRPOTrainer(
                actor_model=model,
                reward_model=reward_model,
                ref_model=ref_model,
                tokenizer=tokenizer,
                learning_rate=1e-6,
                kl_coef=config['grpo_kl_coef'],
                group_size=config['grpo_group_size'],
                update_batch_size=2,
                use_amp=True
            )

            prompt_dataloader = create_posttrain_dataloader(
                mix_name=config['data_mix'],
                tokenizer=tokenizer,
                batch_size=4,
                num_workers=2,
                max_samples=1000,
                split='train'
            )

            # Collect up to 200 instruction batches to use as GRPO prompts.
            prompts = []
            for batch in prompt_dataloader:
                if batch and batch.get('instruction') is not None:
                    prompts.append(batch['instruction'])
                if len(prompts) >= 200:
                    break

            if prompts:
                prompt_tensor = torch.cat(prompts[:200], dim=0)
                from torch.utils.data import TensorDataset, DataLoader
                prompt_loader = DataLoader(
                    TensorDataset(prompt_tensor),
                    batch_size=4
                )

                grpo_trainer.train(
                    prompt_loader,
                    num_iterations=config['grpo_iterations'],
                    max_gen_len=50,
                    save_path=config['checkpoint_dir'] + "/grpo"
                )

        except Exception as e:
            logger.error(f"Error in RLHF: {e}")
            import traceback
            traceback.print_exc()

    logger.info("\n" + "="*80)
    logger.info("All Training Complete!")
    logger.info("="*80)
|
|
|
|
# Script entry point: run the full post-training pipeline when executed
# directly (importing this module triggers no training).
if __name__ == "__main__":
    main()