# dcpo_train.py, from the "MultiModal" repository (scraped from the repository's file view).
# Last change: commit afd1085 "Update dcpo_train.py" by szxllm (verified); file size 13.4 kB.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import json
import logging
from tqdm import tqdm
import glob
from datetime import datetime
import gc
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
from model import MultiModalDenseTransformer
from dcpo import DCPOTrainer
def setup_distributed():
    """Initialise torch.distributed from torchrun-style environment variables.

    Returns:
        A ``(rank, local_rank, world_size)`` tuple. When ``RANK`` and
        ``WORLD_SIZE`` are not both present in the environment, no process
        group is created and ``(0, 0, 1)`` is returned (single-process mode).
    """
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        print("Initialized Single GPU Mode")
        return 0, 0, 1

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # torchrun exports LOCAL_RANK; launchers that only export RANK/WORLD_SIZE
    # fall back to the global rank (which equals the local rank on one node).
    local_rank = int(os.environ.get("LOCAL_RANK", rank))

    # Bind this process to its GPU *before* creating the NCCL group so any
    # NCCL work issued during/after init targets this device, not cuda:0.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    if rank == 0:
        print(f"Initialized DDP: Rank {rank}/{world_size}")
    return rank, local_rank, world_size
# Distributed setup runs once, at import time. IS_MAIN marks the rank-0
# process, which is the only one that prints, logs to file and checkpoints.
RANK, LOCAL_RANK, WORLD_SIZE = setup_distributed()
IS_MAIN = (RANK == 0)
# Module logger: INFO on the main process, WARNING and above everywhere else.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING if not IS_MAIN else logging.INFO)
class MathDataset(Dataset):
    """In-memory dataset backed by a JSON-Lines file.

    Every non-blank line of ``path`` is decoded with ``json.loads`` and kept
    as-is. ``__getitem__`` returns the decoded record unchanged (the training
    collate function expects ``'prompt'`` and ``'ground_truth'`` keys).
    """

    def __init__(self, path):
        with open(path, 'r', encoding='utf-8') as handle:
            # Whitespace-only lines (e.g. a trailing newline) are skipped.
            self.data = [json.loads(raw) for raw in handle if raw.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
def math_collate(batch):
    """Collate a list of dataset records into column-wise lists.

    Returns a dict with ``'prompt'`` and ``'ground_truth'`` lists that are
    aligned with ``batch``. Any other keys in the records are ignored.
    """
    prompts = [sample['prompt'] for sample in batch]
    truths = [sample['ground_truth'] for sample in batch]
    return {'prompt': prompts, 'ground_truth': truths}
def _strip_ddp_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    Checkpoints saved from a ``DistributedDataParallel`` wrapper prefix every
    parameter name with ``'module.'``. Only a *leading* prefix is removed, so
    keys that merely contain ``'module.'`` further in are left untouched.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _ema(previous, value, decay):
    """Return ``decay * previous + (1 - decay) * value``, or ``value`` if ``previous`` is None.

    ``None`` (not ``0``) marks "no history yet", so a genuine running value of
    zero keeps being smoothed instead of being reset to the latest sample.
    """
    if previous is None:
        return value
    return decay * previous + (1 - decay) * value


def _restore_rng_state(checkpoint, local_rank):
    """Restore the CPU and CUDA RNG states stored in a training checkpoint.

    Args:
        checkpoint: Dict loaded with ``torch.load``. It may contain
            ``'rng_state'`` (CPU generator state) and ``'cuda_rng_state'``
            (list of per-device states, or None when saved without CUDA).
        local_rank: Index of this process's CUDA device. Used as a fallback
            when the full per-device list cannot be applied.
    """
    if 'rng_state' in checkpoint:
        torch.set_rng_state(checkpoint['rng_state'])
    cuda_states = checkpoint.get('cuda_rng_state')
    # The checkpoint writer stores None when CUDA was unavailable.
    if cuda_states is None or not torch.cuda.is_available():
        return
    try:
        torch.cuda.set_rng_state_all(cuda_states)
    except (RuntimeError, IndexError):
        # Presumably the saving host exposed a different set of devices;
        # restore only this process's device when a state was saved for it.
        if local_rank < len(cuda_states):
            torch.cuda.set_rng_state(cuda_states[local_rank])
        else:
            logger.warning("Could not restore CUDA RNG state for local rank %d", local_rank)


def main():
    """Fine-tune the multimodal transformer with DCPO on a JSON-Lines math dataset.

    The function proceeds in these stages:

    1. Set up file logging and the metrics file (rank 0 only).
    2. Build the tokenizer, actor model (DDP-wrapped when ``WORLD_SIZE > 1``)
       and ``DCPOTrainer``.
    3. Load weights from one of two sources:
       * ``CONFIG['resume_from']`` restores the model, trainer state, RNG
         states and step counters.
       * ``CONFIG['sft_checkpoint']`` supplies the initial weights otherwise.
    4. Run ``CONFIG['max_steps']`` generate-and-train steps. Rank 0 appends
       metrics to ``metrics.jsonl`` and saves ``step_<n>.pt`` every
       ``save_interval`` steps.
    5. Save ``final_dcpo.pt`` from rank 0 and tear down the process group.

    Relies on the module globals ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``,
    ``IS_MAIN`` and ``logger``, which are set at import time.
    """
    import traceback

    CONFIG = {
        'sft_checkpoint': '/root/checkpoints/dcpo_posttrain_round3/step_1200.pt',
        'data_path': '/root/dataset/r1_zero_math.jsonl',
        'save_dir': '/root/checkpoints/dcpo_training',
        'resume_from': None,  # path to a step_<n>.pt written by this script, or None
        'model_dim': 1536,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,
        'group_size': 4,
        'batch_size': 1,  # prompts per rank per step
        'learning_rate': 1e-6,
        'max_steps': 5000,
        'max_gen_len': 512,
        'save_interval': 1400,
        # The following are forwarded unchanged to DCPOTrainer.
        'dcpo_eps_low': 0.16,
        'dcpo_eps_high': 0.2,
        'dcpo_r_max': 10.0,
        'gradient_accumulation_steps': 8,
        'inner_batch_size': 4,
        'use_reference_comparison': True,
        'use_progressive_reward': False,
        'phase1_steps': 2000,
        'phase2_steps': 4000,
    }

    # --- Logging / metrics output (rank 0 only) -------------------------------
    file_handler = None
    metrics_file = os.path.join(CONFIG['save_dir'], "metrics.jsonl")
    if IS_MAIN:
        os.makedirs(CONFIG['save_dir'], exist_ok=True)
        current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = os.path.join(CONFIG['save_dir'], f"dcpo_train_{current_time}.log")
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(file_handler)
        # Make sure the metrics file exists before the first record is appended.
        if not os.path.exists(metrics_file):
            with open(metrics_file, 'w', encoding='utf-8'):
                pass

    # --- Tokenizer / model / trainer -----------------------------------------
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    def create_model():
        return MultiModalDenseTransformer(
            model_dim=CONFIG['model_dim'],
            vocab_size=len(tokenizer),
            n_layers=CONFIG['n_layers'],
            n_heads=CONFIG['n_heads'],
            n_kv_heads=CONFIG['n_kv_heads'],
            max_seq_len=2048,
            use_gradient_checkpointing=True
        )

    device = torch.device(f"cuda:{LOCAL_RANK}")
    if IS_MAIN:
        print("Initializing Actor Model...")
    actor = create_model().to(device)
    # No separate reference model is built; the trainer receives ref_model=None.
    ref = None
    if WORLD_SIZE > 1:
        actor = DDP(actor, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
    # The bare nn.Module, used for loading and saving weights without the
    # DDP "module." key prefix.
    core_model = actor.module if isinstance(actor, DDP) else actor

    trainer = DCPOTrainer(
        actor_model=actor,
        ref_model=ref,
        tokenizer=tokenizer,
        learning_rate=CONFIG['learning_rate'],
        group_size=CONFIG['group_size'],
        eps_low=CONFIG['dcpo_eps_low'],
        eps_high=CONFIG['dcpo_eps_high'],
        r_max=CONFIG['dcpo_r_max'],
        use_amp=True,
        gradient_accumulation_steps=CONFIG['gradient_accumulation_steps'],
        inner_batch_size=CONFIG['inner_batch_size'],
        use_reference_comparison=CONFIG['use_reference_comparison'],
        use_progressive_reward=CONFIG['use_progressive_reward'],
        phase1_steps=CONFIG['phase1_steps'],
        phase2_steps=CONFIG['phase2_steps']
    )

    # --- Weights: resume checkpoint or SFT initialisation --------------------
    start_step = 0
    samples_seen = 0
    if CONFIG['resume_from']:
        resume_path = CONFIG['resume_from']
        if IS_MAIN:
            print(f"Resuming from: {resume_path}")
        checkpoint = torch.load(resume_path, map_location='cpu')
        core_model.load_state_dict(checkpoint['model_state_dict'])
        if 'trainer_state_dict' in checkpoint:
            trainer.load_state_dict(checkpoint['trainer_state_dict'])
        _restore_rng_state(checkpoint, LOCAL_RANK)
        # The checkpoint stores the last *completed* step.
        start_step = checkpoint.get('step', 0) + 1
        samples_seen = checkpoint.get('samples_seen', start_step * CONFIG['batch_size'] * WORLD_SIZE)
        if CONFIG['use_progressive_reward']:
            trainer.update_step(start_step)
            if IS_MAIN:
                print(f"Restored progressive reward state to step {start_step}")
        del checkpoint
    else:
        if IS_MAIN:
            print(f"Loading SFT checkpoint: {CONFIG['sft_checkpoint']}")
        checkpoint = torch.load(CONFIG['sft_checkpoint'], map_location='cpu')
        # Accept both {'model_state_dict': ...} checkpoints and bare state dicts.
        state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        core_model.load_state_dict(_strip_ddp_prefix(state_dict))
        del checkpoint, state_dict
    gc.collect()
    torch.cuda.empty_cache()

    # --- Data ----------------------------------------------------------------
    dataset = MathDataset(CONFIG['data_path'])
    if WORLD_SIZE > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=WORLD_SIZE, rank=RANK, shuffle=True, seed=42
        )
    else:
        sampler = None
    dataloader = DataLoader(
        dataset, batch_size=CONFIG['batch_size'],
        collate_fn=math_collate, sampler=sampler, shuffle=(sampler is None)
    )
    batches_per_epoch = len(dataloader)
    if batches_per_epoch == 0:
        raise ValueError(f"No training samples found in {CONFIG['data_path']}")

    if IS_MAIN:
        print(f"Starting Training from step {start_step}")

    # One batch is consumed per step, so the epoch index follows from the step.
    # DistributedSampler reshuffles per epoch via set_epoch; a resumed run
    # therefore sees the same order as an uninterrupted one.
    epoch = start_step // batches_per_epoch
    if sampler is not None:
        sampler.set_epoch(epoch)
    data_iter = iter(dataloader)

    def next_batch():
        """Return the next batch, starting a new (reshuffled) epoch on exhaustion."""
        nonlocal data_iter, epoch
        try:
            return next(data_iter)
        except StopIteration:
            epoch += 1
            if sampler is not None:
                sampler.set_epoch(epoch)
            data_iter = iter(dataloader)
            return next(data_iter)

    steps_in_epoch = start_step % batches_per_epoch
    if steps_in_epoch:
        if IS_MAIN:
            print(f"Fast-forwarding dataloader by {steps_in_epoch} steps...")
        for _ in range(steps_in_epoch):
            next_batch()

    # --- Training loop -------------------------------------------------------
    progress_bar = tqdm(
        range(start_step, CONFIG['max_steps']),
        disable=not IS_MAIN,
        initial=start_step,
        total=CONFIG['max_steps'],
        ncols=120,
        mininterval=1.0
    )
    running_reward = None  # EMA of the per-step mean reward; None until first step
    running_loss = None    # EMA of the returned loss; None until the first loss
    for step in progress_bar:
        try:
            if CONFIG['use_progressive_reward']:
                trainer.update_step(step)
            batch = next_batch()
            samples_seen += CONFIG['batch_size'] * WORLD_SIZE

            # Generate + SAS
            experience = trainer.generate_and_prepare(
                batch,
                max_gen_len=CONFIG['max_gen_len']
            )
            step_reward = experience['rewards'].mean().item()
            running_reward = _ema(running_reward, step_reward, 0.95)

            loss = trainer.train_step(experience)

            status_dict = {"Rw": f"{running_reward:.2f}"}
            show_phase = (CONFIG['use_progressive_reward']
                          and hasattr(trainer.math_verifier, 'current_phase'))
            if show_phase:
                status_dict["Ph"] = f"{trainer.math_verifier.current_phase}"

            if loss is not None:
                running_loss = _ema(running_loss, loss, 0.9)
                status_dict["Ls"] = f"{running_loss:.3f}"
                if IS_MAIN:
                    current_lr = trainer.optimizer.param_groups[0]['lr']
                    metrics_data = {
                        "step": step,
                        "running_reward": float(running_reward),
                        "reward": float(step_reward),
                        "loss": float(loss),
                        "lr": float(current_lr),
                        "samples_seen": samples_seen,
                        "timestamp": datetime.now().isoformat()
                    }
                    if show_phase:
                        metrics_data['reward_phase'] = trainer.math_verifier.current_phase
                    with open(metrics_file, "a", encoding='utf-8') as f:
                        f.write(json.dumps(metrics_data) + "\n")
                    if step % 10 == 0:
                        log_msg = f"Step {step} | Reward: {step_reward:.4f} | Loss: {loss:.4f}"
                        progress_bar.write(log_msg)
                        logger.info(log_msg)
            else:
                # No loss returned: presumably still mid gradient accumulation.
                status_dict["St"] = "Acc"
            progress_bar.set_description(' '.join(f'{k}:{v}' for k, v in status_dict.items()))

            is_accum_boundary = (len(trainer.experience_buffer) == 0)
            if step > 0 and step % CONFIG['save_interval'] == 0 and IS_MAIN:
                if not is_accum_boundary:
                    msg = "Saving checkpoint during gradient accumulation! Partial gradients will be lost."
                    progress_bar.write(msg)
                    logger.warning(msg)
                save_path = f"{CONFIG['save_dir']}/step_{step}.pt"
                torch.save({
                    'step': step,
                    'samples_seen': samples_seen,
                    'model_state_dict': core_model.state_dict(),
                    'trainer_state_dict': trainer.state_dict(),
                    'rng_state': torch.get_rng_state(),
                    'cuda_rng_state': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
                }, save_path)
                msg = f"Checkpoint saved: {save_path}"
                progress_bar.write(msg)
                logger.info(msg)
            del experience, batch
        except Exception as e:
            # Best-effort: log and skip the step. NOTE(review): under DDP a
            # failure on one rank while others enter a collective may hang.
            err_msg = f"Step {step} Error: {e}"
            if IS_MAIN:
                progress_bar.write(err_msg)
                logger.error(err_msg)
            else:
                # Non-main ranks have no file handler; report to stdout so
                # their failures are not silently swallowed.
                print(f"[rank {RANK}] {err_msg}", flush=True)
            traceback.print_exc()
            # Drop references so a failed step's tensors are not kept alive
            # while the next step generates.
            batch = experience = None

    # --- Finalisation ----------------------------------------------------------
    if IS_MAIN:
        final_path = f"{CONFIG['save_dir']}/final_dcpo.pt"
        torch.save({'model_state_dict': core_model.state_dict()}, final_path)
        print("DCPO Training Finished.")
    if file_handler is not None:
        logger.removeHandler(file_handler)
        file_handler.close()
    if WORLD_SIZE > 1:
        dist.destroy_process_group()
# Run training only when executed as a script (e.g. via torchrun), not on import.
if __name__ == '__main__':
    main()