import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import json
import logging
from tqdm import tqdm
import glob
from datetime import datetime
import gc
import traceback
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
from model import MultiModalDenseTransformer
from dcpo import DCPOTrainer


def setup_distributed():
    """Initialise torch.distributed when launched with RANK/WORLD_SIZE set.

    Returns:
        A ``(rank, local_rank, world_size)`` tuple. When the ``RANK`` and
        ``WORLD_SIZE`` environment variables are absent, no process group is
        created and ``(0, 0, 1)`` is returned.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        dist.init_process_group(backend="nccl")
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        torch.cuda.set_device(local_rank)
        if rank == 0:
            print(f"Initialized DDP: Rank {rank}/{world_size}")
        return rank, local_rank, world_size

    print("Initialized Single GPU Mode")
    return 0, 0, 1


# NOTE: the process group is set up at import time; everything below relies
# on these module-level globals.
RANK, LOCAL_RANK, WORLD_SIZE = setup_distributed()
IS_MAIN = RANK == 0

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO if IS_MAIN else logging.WARNING)


class MathDataset(Dataset):
    """In-memory dataset backed by a JSON-lines file (one JSON object per line)."""

    def __init__(self, path):
        """Load every non-blank line of *path* as a JSON record."""
        self.data = []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def math_collate(batch):
    """Collate records into parallel ``prompt`` / ``ground_truth`` lists."""
    return {
        'prompt': [item['prompt'] for item in batch],
        'ground_truth': [item['ground_truth'] for item in batch]
    }


def _strip_ddp_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Only the prefix is stripped: a key such as ``encoder.submodule.weight``
    contains ``module.`` in the middle and must be left untouched.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def main():
    """Run DCPO training: build the model, load weights, then loop over steps.

    Either resumes from ``CONFIG['resume_from']`` (model, trainer and RNG
    state) or starts from ``CONFIG['sft_checkpoint']``. Rank 0 writes a log
    file, a ``metrics.jsonl`` file, periodic checkpoints and a final
    checkpoint into ``CONFIG['save_dir']``.

    Raises:
        ValueError: if the dataset at ``CONFIG['data_path']`` has no records.
    """
    CONFIG = {
        'sft_checkpoint': '/root/checkpoints/dcpo_posttrain_round3/step_1200.pt',
        'data_path': '/root/dataset/r1_zero_math.jsonl',
        'save_dir': '/root/checkpoints/dcpo_training',
        'resume_from': None,
        'model_dim': 1536,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'group_size': 4,
        'batch_size': 1,
        'learning_rate': 1e-6,
        'max_steps': 5000,
        'max_gen_len': 512,
        'save_interval': 1400,
        'dcpo_eps_low': 0.16,
        'dcpo_eps_high': 0.2,
        'dcpo_r_max': 10.0,
        'gradient_accumulation_steps': 8,
        'inner_batch_size': 4,
        'use_reference_comparison': True,
        'use_progressive_reward': False,
        'phase1_steps': 2000,
        'phase2_steps': 4000,
    }

    # Computed once on every rank; only rank 0 ever writes to it.
    metrics_file = os.path.join(CONFIG['save_dir'], "metrics.jsonl")

    file_handler = None
    if IS_MAIN:
        os.makedirs(CONFIG['save_dir'], exist_ok=True)
        current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = os.path.join(CONFIG['save_dir'], f"dcpo_train_{current_time}.log")
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        )
        logger.addHandler(file_handler)

        # Create an empty metrics file on the first run; later runs append.
        if not os.path.exists(metrics_file):
            with open(metrics_file, 'w', encoding='utf-8') as f:
                pass

    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    def create_model():
        """Build the transformer with the dimensions given in CONFIG."""
        return MultiModalDenseTransformer(
            model_dim=CONFIG['model_dim'],
            vocab_size=len(tokenizer),
            n_layers=CONFIG['n_layers'],
            n_heads=CONFIG['n_heads'],
            n_kv_heads=CONFIG['n_kv_heads'],
            max_seq_len=2048,
            use_gradient_checkpointing=True
        )

    device = torch.device(f"cuda:{LOCAL_RANK}")

    if IS_MAIN:
        print("Initializing Actor Model...")
    actor = create_model().to(device)
    ref = None

    if WORLD_SIZE > 1:
        actor = DDP(actor, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

    # The bare module, used for loading/saving weights without the DDP prefix.
    raw_model = actor.module if isinstance(actor, DDP) else actor

    trainer = DCPOTrainer(
        actor_model=actor,
        ref_model=ref,
        tokenizer=tokenizer,
        learning_rate=CONFIG['learning_rate'],
        group_size=CONFIG['group_size'],
        eps_low=CONFIG['dcpo_eps_low'],
        eps_high=CONFIG['dcpo_eps_high'],
        r_max=CONFIG['dcpo_r_max'],
        use_amp=True,
        gradient_accumulation_steps=CONFIG['gradient_accumulation_steps'],
        inner_batch_size=CONFIG['inner_batch_size'],
        use_reference_comparison=CONFIG['use_reference_comparison'],
        use_progressive_reward=CONFIG['use_progressive_reward'],
        phase1_steps=CONFIG['phase1_steps'],
        phase2_steps=CONFIG['phase2_steps']
    )

    start_step = 0
    samples_seen = 0

    if CONFIG['resume_from']:
        resume_path = CONFIG['resume_from']
        if IS_MAIN:
            print(f"Resuming from: {resume_path}")

        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(resume_path, map_location='cpu')
        raw_model.load_state_dict(checkpoint['model_state_dict'])

        if 'trainer_state_dict' in checkpoint:
            trainer.load_state_dict(checkpoint['trainer_state_dict'])

        if 'rng_state' in checkpoint:
            torch.set_rng_state(checkpoint['rng_state'])

        # The save path below stores None when CUDA is unavailable, so the key
        # being present does not guarantee a usable state.
        cuda_rng_state = checkpoint.get('cuda_rng_state')
        if cuda_rng_state is not None:
            try:
                torch.cuda.set_rng_state_all(cuda_rng_state)
            except Exception:
                # Fall back to restoring only this rank's entry.
                torch.cuda.set_rng_state(cuda_rng_state[LOCAL_RANK])

        start_step = checkpoint.get('step', 0) + 1
        samples_seen = checkpoint.get(
            'samples_seen', start_step * CONFIG['batch_size'] * WORLD_SIZE
        )

        if CONFIG['use_progressive_reward']:
            trainer.update_step(start_step)
            if IS_MAIN:
                print(f"Restored progressive reward state to step {start_step}")

        del checkpoint
        gc.collect()
        torch.cuda.empty_cache()
    else:
        if IS_MAIN:
            print(f"Loading SFT checkpoint: {CONFIG['sft_checkpoint']}")

        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(CONFIG['sft_checkpoint'], map_location='cpu')
        state_dict = (
            checkpoint['model_state_dict']
            if 'model_state_dict' in checkpoint
            else checkpoint
        )
        new_state_dict = _strip_ddp_prefix(state_dict)
        raw_model.load_state_dict(new_state_dict)

        del checkpoint, state_dict, new_state_dict
        gc.collect()
        torch.cuda.empty_cache()

    dataset = MathDataset(CONFIG['data_path'])
    if len(dataset) == 0:
        # Without this the epoch arithmetic below divides by zero.
        raise ValueError(f"Dataset is empty: {CONFIG['data_path']}")

    if WORLD_SIZE > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset,
            num_replicas=WORLD_SIZE,
            rank=RANK,
            shuffle=True,
            seed=42
        )
    else:
        sampler = None

    dataloader = DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        collate_fn=math_collate,
        sampler=sampler,
        shuffle=(sampler is None)
    )
    steps_per_epoch = len(dataloader)

    if IS_MAIN:
        print(f"Starting Training from step {start_step}")

    # `sampler is not None` rather than truthiness: a sampler's truth value
    # is derived from its length, not from its presence.
    epoch = start_step // steps_per_epoch
    if sampler is not None:
        sampler.set_epoch(epoch)

    data_iter = iter(dataloader)

    # When resuming mid-epoch, skip the batches already consumed.
    steps_in_epoch = start_step % steps_per_epoch
    if start_step > 0 and steps_in_epoch > 0:
        if IS_MAIN:
            print(f"Fast-forwarding dataloader by {steps_in_epoch} steps...")
        for _ in range(steps_in_epoch):
            try:
                next(data_iter)
            except StopIteration:
                epoch += 1
                if sampler is not None:
                    sampler.set_epoch(epoch)
                data_iter = iter(dataloader)
                next(data_iter)

    progress_bar = tqdm(
        range(start_step, CONFIG['max_steps']),
        disable=not IS_MAIN,
        initial=start_step,
        total=CONFIG['max_steps'],
        ncols=120,
        mininterval=1.0
    )

    running_reward = 0.0
    running_loss = 0.0

    for step in progress_bar:
        try:
            if CONFIG['use_progressive_reward']:
                trainer.update_step(step)

            try:
                batch = next(data_iter)
            except StopIteration:
                # Advance exactly one epoch so the shuffle order matches the
                # `start_step // steps_per_epoch` used when resuming.
                epoch += 1
                if sampler is not None:
                    sampler.set_epoch(epoch)
                data_iter = iter(dataloader)
                batch = next(data_iter)

            samples_seen += CONFIG['batch_size'] * WORLD_SIZE

            # Generation + SAS (translated from the original comment; "SAS"
            # is not defined in this file).
            experience = trainer.generate_and_prepare(
                batch,
                max_gen_len=CONFIG['max_gen_len']
            )

            step_reward = experience['rewards'].mean().item()
            # Exponential moving average; a value of exactly 0 is treated as
            # "not initialised yet".
            if running_reward == 0:
                running_reward = step_reward
            else:
                running_reward = 0.95 * running_reward + 0.05 * step_reward

            loss = trainer.train_step(experience)

            status_dict = {"Rw": f"{running_reward:.2f}"}
            if CONFIG['use_progressive_reward'] and hasattr(trainer.math_verifier, 'current_phase'):
                status_dict["Ph"] = f"{trainer.math_verifier.current_phase}"

            if loss is not None:
                if running_loss == 0:
                    running_loss = loss
                else:
                    running_loss = 0.9 * running_loss + 0.1 * loss
                status_dict["Ls"] = f"{running_loss:.3f}"

                if IS_MAIN:
                    current_lr = trainer.optimizer.param_groups[0]['lr']
                    metrics_data = {
                        "step": step,
                        "running_reward": float(running_reward),
                        "reward": float(step_reward),
                        "loss": float(loss),
                        "lr": float(current_lr),
                        "samples_seen": samples_seen,
                        "timestamp": datetime.now().isoformat()
                    }
                    if CONFIG['use_progressive_reward'] and hasattr(trainer.math_verifier, 'current_phase'):
                        metrics_data['reward_phase'] = trainer.math_verifier.current_phase

                    with open(metrics_file, "a", encoding='utf-8') as f:
                        f.write(json.dumps(metrics_data) + "\n")

                    if step % 10 == 0:
                        log_msg = f"Step {step} | Reward: {step_reward:.4f} | Loss: {loss:.4f}"
                        progress_bar.write(log_msg)
                        if file_handler:
                            # Written straight to the file handler instead of
                            # through `logger`, so only the log file gets it.
                            file_handler.emit(logging.LogRecord(
                                name="train",
                                level=logging.INFO,
                                pathname=__file__,
                                lineno=0,
                                msg=log_msg,
                                args=(),
                                exc_info=None
                            ))
            else:
                # train_step returned no loss -- presumably the trainer is
                # still accumulating gradients; confirm against DCPOTrainer.
                status_dict["St"] = "Acc"

            progress_bar.set_description(
                ' '.join(f'{k}:{v}' for k, v in status_dict.items())
            )

            is_accum_boundary = (len(trainer.experience_buffer) == 0)

            if step > 0 and step % CONFIG['save_interval'] == 0 and IS_MAIN:
                if not is_accum_boundary:
                    msg = "Saving checkpoint during gradient accumulation! Partial gradients will be lost."
                    progress_bar.write(msg)
                    if file_handler:
                        logger.warning(msg)

                save_path = f"{CONFIG['save_dir']}/step_{step}.pt"
                torch.save({
                    'step': step,
                    'samples_seen': samples_seen,
                    'model_state_dict': raw_model.state_dict(),
                    'trainer_state_dict': trainer.state_dict(),
                    'rng_state': torch.get_rng_state(),
                    'cuda_rng_state': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
                }, save_path)

                msg = f"Checkpoint saved: {save_path}"
                progress_bar.write(msg)
                if file_handler:
                    logger.info(msg)

            del experience
            del batch

        except Exception as e:
            # Deliberate best-effort: a failing step is reported and skipped
            # so that one bad sample does not abort the whole run.
            err_msg = f"Step {step} Error: {e}"
            if IS_MAIN:
                progress_bar.write(err_msg)
                logger.error(err_msg)
                traceback.print_exc()
            continue

    if IS_MAIN:
        final_path = f"{CONFIG['save_dir']}/final_dcpo.pt"
        torch.save({'model_state_dict': raw_model.state_dict()}, final_path)
        print("DCPO Training Finished.")

    if WORLD_SIZE > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()