import json
import logging
import os
import re
from enum import Enum

import torch

import verl.utils.hdfs_io as hdfs_io
from verl.single_controller import WorkerGroup
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, get_checkpoint_tracker_filename
from verl.utils.logger import log_with_rank
from verl.workers.engine import BaseEngine


def extract_step(path):
    """Extract the global step from a checkpoint path, or return None if absent.

    >>> extract_step("/ckpts/global_step_42")
    42
    """
    match = re.search(r"global_step_(\d+)", path)
    if match:
        return int(match.group(1))
    return None


logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN"))


class OrchestrationMode(Enum):
    SPMD = 0
    RAY = 1


class CheckpointHandler:
    """
    CheckpointHandler manages the path and global_step bookkeeping of a checkpoint folder.
    Currently it only works with a single model; it can be extended to support multiple models.
    It is expected to be used in SPMD style (e.g., with torchrun).
    """

    def __init__(
        self,
        engine: BaseEngine | WorkerGroup,
        train_dataloader,
        *,
        default_local_dir,
        max_ckpt_to_keep=None,
        default_hdfs_dir=None,
        resume_mode="auto",
        resume_from_path=None,
        mode=OrchestrationMode.SPMD,
        lora_train_meta=None,
    ):
        self.default_local_dir = default_local_dir
        self.max_ckpt_to_keep = max_ckpt_to_keep
        self.default_hdfs_dir = default_hdfs_dir
        self.resume_mode = resume_mode
        self.resume_from_path = resume_from_path
        self.engine = engine
        self.train_dataloader = train_dataloader
        self.mode = mode
        self.lora_train_meta = lora_train_meta

        if self.mode == OrchestrationMode.SPMD:
            self.rank = torch.distributed.get_rank()
            self.is_mp_src_rank_with_outputs = self.engine.is_mp_src_rank_with_outputs()
            self.dp_rank = self.engine.get_data_parallel_rank()
        elif self.mode == OrchestrationMode.RAY:
            self.rank = 0
            self.is_mp_src_rank_with_outputs = True
            self.dp_rank = 0
        else:
            raise ValueError(f"Unknown {self.mode=}")

    def save_checkpoint(self, step):
        """Save checkpoint using FSDPCheckpointManager with improved tracking"""
        from verl.utils.fs import local_mkdir_safe

        # Each checkpoint lives in its own global_step_{step} folder.
        local_global_step_folder = os.path.join(self.default_local_dir, f"global_step_{step}")
        if self.rank == 0:
            print(f"Saving checkpoint to: {local_global_step_folder}")

        max_ckpt_to_keep = self.max_ckpt_to_keep

        # Delegate model/optimizer state saving (and old-checkpoint rotation) to the engine.
        self.engine.save_checkpoint(
            local_path=local_global_step_folder, global_step=step, max_ckpt_to_keep=max_ckpt_to_keep
        )

        # Persist LoRA rank/alpha metadata alongside the model weights.
        if self.rank == 0 and self.lora_train_meta is not None:
            local_mkdir_safe(local_global_step_folder)
            lora_meta_path = os.path.join(local_global_step_folder, "lora_train_meta.json")
            with open(lora_meta_path, "w", encoding="utf-8") as f:
                json.dump(self.lora_train_meta, f, ensure_ascii=False, indent=4)
            print(f"Saved LoRA rank/alpha metadata to: {lora_meta_path}")

        if self.is_mp_src_rank_with_outputs:
            dp_rank = self.dp_rank
            local_mkdir_safe(local_global_step_folder)
            dataloader_local_path = os.path.join(local_global_step_folder, f"data_{dp_rank}.pt")

            # Save one dataloader state file per data-parallel rank.
            dataloader_state_dict = self.train_dataloader.state_dict()
            torch.save(dataloader_state_dict, dataloader_local_path)
            print(f"Saved dataloader state to: {dataloader_local_path}")

        if self.rank == 0:
            # Update the tracker file atomically: write to a temp file, then rename.
            tracker_file = get_checkpoint_tracker_filename(self.default_local_dir)
            temp_tracker_file = tracker_file + ".tmp"
            with open(temp_tracker_file, "w") as f:
                f.write(str(step))
            os.rename(temp_tracker_file, tracker_file)
            print(f"Updated checkpoint tracker: {tracker_file}")

        # Optionally mirror the checkpoint folder to HDFS.
        if self.rank == 0 and self.default_hdfs_dir:
            hdfs_io.makedirs(self.default_hdfs_dir, exist_ok=True)
            hdfs_io.copy(src=local_global_step_folder, dst=self.default_hdfs_dir, dirs_exist_ok=True)

        if self.mode == OrchestrationMode.SPMD:
            torch.distributed.barrier()

    def load_checkpoint(self):
        """Resume from a checkpoint according to resume_mode; return the resumed global step (0 if none)."""
        checkpoint_path = self._determine_resume_path()

        if checkpoint_path is None:
            return 0

        # Extract the global step number from the checkpoint path.
        resume_step = extract_step(checkpoint_path)
        if resume_step is None:
            log_with_rank(
                f"Warning: Could not extract step number from {checkpoint_path}, starting from step 0",
                logger=logger,
                rank=self.rank,
                level=logging.WARNING,
                log_only_rank_0=True,
            )
            return 0
        self.resume_global_step = resume_step

        # Restore model/optimizer state first, then the per-rank dataloader state.
        self.engine.load_checkpoint(checkpoint_path)
        self._load_dataloader_state(checkpoint_path)

        return resume_step

    def _load_dataloader_state(self, checkpoint_path: str):
        """Load dataloader state from checkpoint"""
        dp_rank = self.dp_rank
        dataloader_path = os.path.join(checkpoint_path, f"data_{dp_rank}.pt")

        if os.path.exists(dataloader_path):
            dataloader_state_dict = torch.load(dataloader_path, map_location="cpu", weights_only=False)
            self.train_dataloader.load_state_dict(dataloader_state_dict)

            log_with_rank(
                f"Successfully loaded dataloader state from {dataloader_path}",
                logger=logger,
                rank=self.rank,
                log_only_rank_0=True,
            )
        else:
            log_with_rank(
                f"Warning: No dataloader state found at {dataloader_path}, will start from scratch",
                logger=logger,
                rank=self.rank,
                level=logging.WARNING,
                log_only_rank_0=True,
            )

    def _determine_resume_path(self):
        """Determine the path to resume from based on the resume_mode configuration."""
        resume_mode = self.resume_mode
        resume_from_path = self.resume_from_path

        if resume_mode == "disable":
            return None
        elif resume_mode == "auto":
            if resume_from_path is not None:
                assert os.path.exists(resume_from_path), (
                    "resume_from_path must be null or an existing path when resume_mode is 'auto'"
                )
                assert "global_step_" in resume_from_path, "resume_from_path must specify the global_steps"
                return resume_from_path
            # Fall back to the most recent checkpoint under default_local_dir.
            return self._find_latest_checkpoint()
        elif resume_mode == "resume_path":
            assert os.path.exists(resume_from_path), (
                "resume_from_path must be an existing path when resume_mode is 'resume_path'"
            )
            assert "global_step_" in resume_from_path, "resume_from_path must specify the global_steps"
            return resume_from_path
        else:
            raise ValueError(f"Invalid resume_mode: {resume_mode}. Must be 'auto', 'disable', or 'resume_path'")

    def _find_latest_checkpoint(self):
        """Find the latest checkpoint in the default local directory."""
        checkpoint_dir = self.default_local_dir

        if not os.path.exists(checkpoint_dir):
            return None

        latest_checkpoint = find_latest_ckpt_path(checkpoint_dir)

        if latest_checkpoint and self.rank == 0:
            step_num = extract_step(latest_checkpoint)
            print(f"Found latest checkpoint: {latest_checkpoint} (step {step_num})")

        return latest_checkpoint


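# Usage sketch (illustrative only, not part of this module's API): how the
# handler would typically be driven by an SPMD SFT training loop. `engine`,
# `train_dataloader`, `total_steps`, and `save_freq` are assumed to come from
# the surrounding trainer setup.
#
#     handler = CheckpointHandler(
#         engine=engine,
#         train_dataloader=train_dataloader,
#         default_local_dir="checkpoints/my_run",
#         max_ckpt_to_keep=3,
#         resume_mode="auto",
#     )
#     start_step = handler.load_checkpoint()  # 0 when there is nothing to resume
#     for step in range(start_step + 1, total_steps + 1):
#         ...  # one optimization step
#         if step % save_freq == 0:
#             handler.save_checkpoint(step)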