| import os |
| import torch |
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from transformers import AutoTokenizer |
| from torch.utils.data import DataLoader, Dataset |
| import json |
| import logging |
| from tqdm import tqdm |
| import glob |
| from datetime import datetime |
| import gc |
| import warnings |
|
|
# Silence FutureWarning noise (deprecation notices from torch / transformers)
# for the whole process; equivalent to filterwarnings("ignore", category=...).
warnings.simplefilter("ignore", category=FutureWarning)
|
|
| from model import MultiModalDenseTransformer |
| from dcpo import DCPOTrainer |
|
|
def setup_distributed():
    """Initialise ``torch.distributed`` when launched by a multi-process launcher.

    The distributed path is taken when both ``RANK`` and ``WORLD_SIZE`` are
    present in the environment. ``LOCAL_RANK`` is used when exported (as
    torchrun does). Otherwise it is derived from ``RANK`` modulo the number
    of visible GPUs.

    Returns:
        tuple[int, int, int]: ``(rank, local_rank, world_size)``;
        ``(0, 0, 1)`` in single-process mode.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        if "LOCAL_RANK" in os.environ:
            local_rank = int(os.environ["LOCAL_RANK"])
        else:
            # Some launchers export only RANK/WORLD_SIZE. This assumes ranks
            # are packed densely onto each node's GPUs; max(..., 1) avoids a
            # ZeroDivisionError when no GPU is visible.
            local_rank = rank % max(torch.cuda.device_count(), 1)
        # Bind this process to its GPU *before* creating the NCCL group, so
        # NCCL communicators (and any early barrier) use the right device
        # instead of every rank defaulting to cuda:0.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl")
        if rank == 0:
            print(f"Initialized DDP: Rank {rank}/{world_size}")
        return rank, local_rank, world_size

    print("Initialized Single GPU Mode")
    return 0, 0, 1
|
|
# Bring up torch.distributed (or single-GPU mode) once, at import time; the
# resulting rank information is shared by the rest of the module as globals.
RANK, LOCAL_RANK, WORLD_SIZE = setup_distributed()
IS_MAIN = (RANK == 0)  # rank 0 is the only process that prints and saves

# The main process logs at INFO; every other rank only surfaces warnings/errors.
logger = logging.getLogger(__name__)
if IS_MAIN:
    logger.setLevel(logging.INFO)
else:
    logger.setLevel(logging.WARNING)
|
|
class MathDataset(Dataset):
    """Map-style dataset backed by a JSON Lines file.

    Every non-blank line of *path* is decoded with ``json.loads`` and kept in
    memory in file order. Indexing returns the decoded object unchanged.

    Raises:
        OSError: if *path* cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """

    def __init__(self, path):
        with open(path, 'r', encoding='utf-8') as source:
            # Whitespace-only lines (e.g. a trailing newline) are skipped.
            self.data = [json.loads(record) for record in source if record.strip()]

    def __len__(self):
        """Number of decoded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the decoded record at position *idx*."""
        record = self.data[idx]
        return record
|
|
def math_collate(batch):
    """Turn a list of records into parallel column lists.

    Returns:
        dict: ``{'prompt': [...], 'ground_truth': [...]}``, both in *batch*
        order. Extra keys on a record are ignored.

    Raises:
        KeyError: if a record lacks ``'prompt'`` or ``'ground_truth'``.
    """
    prompts = [record['prompt'] for record in batch]
    answers = [record['ground_truth'] for record in batch]
    return {'prompt': prompts, 'ground_truth': answers}
|
|
def _unwrap(model):
    """Return the module wrapped by DistributedDataParallel, or *model* itself."""
    return model.module if isinstance(model, DDP) else model


def _strip_ddp_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    Only the prefix added by DistributedDataParallel is removed. A
    ``module.`` occurring later in a key (e.g. ``vision_module.proj``) is kept.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _update_ema(previous, value, decay):
    """Exponential moving average that is seeded with *value* when *previous* is None."""
    if previous is None:
        return value
    return decay * previous + (1 - decay) * value


def _restore_rng_state(checkpoint):
    """Restore the CPU/CUDA RNG states stored in a training checkpoint dict."""
    if 'rng_state' in checkpoint:
        torch.set_rng_state(checkpoint['rng_state'])
    cuda_states = checkpoint.get('cuda_rng_state')
    # The checkpoint writer stores None when CUDA was unavailable at save time.
    if cuda_states is None:
        return
    try:
        torch.cuda.set_rng_state_all(cuda_states)
    except (RuntimeError, IndexError):
        # Saved and current device counts differ: restore only this rank's device.
        torch.cuda.set_rng_state(cuda_states[LOCAL_RANK])


def _save_checkpoint(path, step, samples_seen, model, trainer):
    """Write a resumable checkpoint (weights, trainer state, RNG state) to *path*."""
    torch.save({
        'step': step,
        'samples_seen': samples_seen,
        'model_state_dict': _unwrap(model).state_dict(),
        'trainer_state_dict': trainer.state_dict(),
        'rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    }, path)


def main():
    """Run DCPO fine-tuning of the dense transformer.

    The run starts from one of two sources:

    * the resume checkpoint in ``CONFIG['resume_from']``, restoring weights,
      trainer state, RNG state and step counters;
    * otherwise the SFT checkpoint in ``CONFIG['sft_checkpoint']`` (weights only).

    It then alternates ``trainer.generate_and_prepare`` and
    ``trainer.train_step`` for ``CONFIG['max_steps']`` steps, one batch per
    step. Rank 0 writes the following into ``save_dir``:

    * a timestamped log file;
    * ``metrics.jsonl``;
    * periodic ``step_<N>.pt`` checkpoints;
    * a final ``final_dcpo.pt``.
    """
    import traceback

    CONFIG = {
        'sft_checkpoint': '/root/checkpoints/dcpo_posttrain_round3/step_1200.pt',
        'data_path': '/root/dataset/r1_zero_math.jsonl',
        'save_dir': '/root/checkpoints/dcpo_training',
        'resume_from': None,

        'model_dim': 1536,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,

        'group_size': 4,
        'batch_size': 1,
        'learning_rate': 1e-6,
        'max_steps': 5000,
        'max_gen_len': 512,
        'save_interval': 1400,

        'dcpo_eps_low': 0.16,
        'dcpo_eps_high': 0.2,
        'dcpo_r_max': 10.0,

        'gradient_accumulation_steps': 8,
        'inner_batch_size': 4,

        'use_reference_comparison': True,
        'use_progressive_reward': False,
        'phase1_steps': 2000,
        'phase2_steps': 4000,
    }

    # ---- logging / metrics (rank 0 only) ----
    file_handler = None
    metrics_file = os.path.join(CONFIG['save_dir'], "metrics.jsonl")
    if IS_MAIN:
        os.makedirs(CONFIG['save_dir'], exist_ok=True)
        current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = os.path.join(CONFIG['save_dir'], f"dcpo_train_{current_time}.log")

        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(file_handler)

        # Only rank 0 writes metrics, and only rank 0 is guaranteed to have
        # created save_dir, so the file is touched here. Append mode creates
        # it without truncating existing metrics.
        with open(metrics_file, 'a', encoding='utf-8'):
            pass

    # ---- tokenizer / model / trainer ----
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    def create_model():
        return MultiModalDenseTransformer(
            model_dim=CONFIG['model_dim'],
            vocab_size=len(tokenizer),
            n_layers=CONFIG['n_layers'],
            n_heads=CONFIG['n_heads'],
            n_kv_heads=CONFIG['n_kv_heads'],
            max_seq_len=2048,
            use_gradient_checkpointing=True
        )

    device = torch.device(f"cuda:{LOCAL_RANK}")

    if IS_MAIN:
        print("Initializing Actor Model...")

    actor = create_model().to(device)
    ref = None

    if WORLD_SIZE > 1:
        actor = DDP(actor, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

    trainer = DCPOTrainer(
        actor_model=actor,
        ref_model=ref,
        tokenizer=tokenizer,
        learning_rate=CONFIG['learning_rate'],
        group_size=CONFIG['group_size'],
        eps_low=CONFIG['dcpo_eps_low'],
        eps_high=CONFIG['dcpo_eps_high'],
        r_max=CONFIG['dcpo_r_max'],
        use_amp=True,
        gradient_accumulation_steps=CONFIG['gradient_accumulation_steps'],
        inner_batch_size=CONFIG['inner_batch_size'],
        use_reference_comparison=CONFIG['use_reference_comparison'],
        use_progressive_reward=CONFIG['use_progressive_reward'],
        phase1_steps=CONFIG['phase1_steps'],
        phase2_steps=CONFIG['phase2_steps']
    )

    # ---- weights: resume checkpoint or SFT initialisation ----
    start_step = 0
    samples_seen = 0

    if CONFIG['resume_from']:
        resume_path = CONFIG['resume_from']
        if IS_MAIN:
            print(f"Resuming from: {resume_path}")

        checkpoint = torch.load(resume_path, map_location='cpu')
        # Resume checkpoints are written from the unwrapped model, so keys
        # carry no DDP prefix.
        _unwrap(actor).load_state_dict(checkpoint['model_state_dict'])

        if 'trainer_state_dict' in checkpoint:
            trainer.load_state_dict(checkpoint['trainer_state_dict'])

        _restore_rng_state(checkpoint)

        start_step = checkpoint.get('step', 0) + 1
        samples_seen = checkpoint.get('samples_seen', start_step * CONFIG['batch_size'] * WORLD_SIZE)

        if CONFIG['use_progressive_reward']:
            trainer.update_step(start_step)
            if IS_MAIN:
                print(f"Restored progressive reward state to step {start_step}")

        del checkpoint
    else:
        if IS_MAIN:
            print(f"Loading SFT checkpoint: {CONFIG['sft_checkpoint']}")
        checkpoint = torch.load(CONFIG['sft_checkpoint'], map_location='cpu')
        state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        _unwrap(actor).load_state_dict(_strip_ddp_prefix(state_dict))
        del checkpoint, state_dict

    gc.collect()
    torch.cuda.empty_cache()

    # ---- data ----
    dataset = MathDataset(CONFIG['data_path'])
    sampler = None
    if WORLD_SIZE > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=WORLD_SIZE, rank=RANK, shuffle=True, seed=42
        )

    dataloader = DataLoader(
        dataset, batch_size=CONFIG['batch_size'],
        collate_fn=math_collate, sampler=sampler, shuffle=(sampler is None)
    )
    steps_per_epoch = len(dataloader)
    if steps_per_epoch == 0:
        raise ValueError(f"No training samples found in {CONFIG['data_path']}")

    if IS_MAIN:
        print(f"Starting Training from step {start_step}")

    # Every step consumes exactly one batch, so the epoch index of a step is
    # step // steps_per_epoch. The same formula is used when the iterator is
    # exhausted below, so resumed and uninterrupted runs see the same order.
    if sampler is not None:
        sampler.set_epoch(start_step // steps_per_epoch)

    data_iter = iter(dataloader)
    steps_in_epoch = start_step % steps_per_epoch
    if steps_in_epoch > 0:
        if IS_MAIN:
            print(f"Fast-forwarding dataloader by {steps_in_epoch} steps...")
        # steps_in_epoch < steps_per_epoch, so this cannot exhaust the iterator.
        for _ in range(steps_in_epoch):
            next(data_iter)

    progress_bar = tqdm(
        range(start_step, CONFIG['max_steps']),
        disable=not IS_MAIN,
        initial=start_step,
        total=CONFIG['max_steps'],
        ncols=120,
        mininterval=1.0
    )

    # None means "no value yet"; 0 is a legitimate reward/loss and must not
    # re-seed the moving averages.
    running_reward = None
    running_loss = None

    for step in progress_bar:
        try:
            if CONFIG['use_progressive_reward']:
                trainer.update_step(step)

            try:
                batch = next(data_iter)
            except StopIteration:
                if sampler is not None:
                    sampler.set_epoch(step // steps_per_epoch)
                data_iter = iter(dataloader)
                batch = next(data_iter)

            samples_seen += CONFIG['batch_size'] * WORLD_SIZE

            experience = trainer.generate_and_prepare(
                batch,
                max_gen_len=CONFIG['max_gen_len']
            )

            step_reward = experience['rewards'].mean().item()
            running_reward = _update_ema(running_reward, step_reward, 0.95)

            loss = trainer.train_step(experience)

            status_dict = {"Rw": f"{running_reward:.2f}"}

            track_phase = (CONFIG['use_progressive_reward']
                           and hasattr(trainer.math_verifier, 'current_phase'))
            if track_phase:
                status_dict["Ph"] = f"{trainer.math_verifier.current_phase}"

            # train_step returns None while gradients are still accumulating.
            if loss is not None:
                loss_value = float(loss)
                running_loss = _update_ema(running_loss, loss_value, 0.9)
                status_dict["Ls"] = f"{running_loss:.3f}"

                if IS_MAIN:
                    metrics_data = {
                        "step": step,
                        "running_reward": float(running_reward),
                        "reward": float(step_reward),
                        "loss": loss_value,
                        "lr": float(trainer.optimizer.param_groups[0]['lr']),
                        "samples_seen": samples_seen,
                        "timestamp": datetime.now().isoformat()
                    }
                    if track_phase:
                        metrics_data['reward_phase'] = trainer.math_verifier.current_phase

                    with open(metrics_file, "a", encoding='utf-8') as f:
                        f.write(json.dumps(metrics_data) + "\n")

                    if step % 10 == 0:
                        log_msg = f"Step {step} | Reward: {step_reward:.4f} | Loss: {loss_value:.4f}"
                        progress_bar.write(log_msg)
                        logger.info(log_msg)
            else:
                status_dict["St"] = "Acc"

            progress_bar.set_description(' '.join(f'{k}:{v}' for k, v in status_dict.items()))

            if step > 0 and step % CONFIG['save_interval'] == 0 and IS_MAIN:
                # A non-empty experience buffer means we are mid-accumulation.
                if len(trainer.experience_buffer) != 0:
                    msg = "Saving checkpoint during gradient accumulation! Partial gradients will be lost."
                    progress_bar.write(msg)
                    logger.warning(msg)

                save_path = f"{CONFIG['save_dir']}/step_{step}.pt"
                _save_checkpoint(save_path, step, samples_seen, actor, trainer)

                msg = f"Checkpoint saved: {save_path}"
                progress_bar.write(msg)
                logger.info(msg)

            del experience, batch

        except Exception as e:
            err_msg = f"Step {step} Error: {e}"
            if IS_MAIN:
                progress_bar.write(err_msg)
                logger.exception(err_msg)
            traceback.print_exc()
            if WORLD_SIZE > 1:
                # Skipping a step on one rank while its peers enter the DDP
                # gradient all-reduce would hang the whole job; fail fast so
                # the launcher can tear the group down.
                raise

    if IS_MAIN:
        final_path = f"{CONFIG['save_dir']}/final_dcpo.pt"
        torch.save({'model_state_dict': _unwrap(actor).state_dict()}, final_path)
        print("DCPO Training Finished.")
        logger.removeHandler(file_handler)
        file_handler.close()

    # setup_distributed() initialises the group whenever RANK/WORLD_SIZE are
    # set, including WORLD_SIZE == 1, so tear down based on actual state.
    if dist.is_initialized():
        dist.destroy_process_group()
|
|
if __name__ == "__main__":
    # Start training only when run as a script. Note that merely importing
    # this module already calls setup_distributed() at module level.
    main()