File size: 13,661 Bytes
bcdf9fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from typing import Optional

import numpy as np
import torch
import torch.distributed
from megatron.core import mpu, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedObject

from verl.models.weight_loader_registry import get_weight_saver
from verl.utils.fs import is_non_local
from verl.utils.megatron_utils import (
    get_hf_model_checkpoint_path,
    get_model_checkpoint_path,
    get_optimizer_checkpoint_path,
    get_rng_states_checkpoint_path,
)

from .checkpoint_manager import BaseCheckpointManager


class MegatronCheckpointManager(BaseCheckpointManager):
    """A checkpoint manager that saves and loads

    - model
    - optimizer
    - lr_scheduler
    - extra_states

    in a SPMD way.

    We save
    - sharded model states and optimizer states
    - full lr_scheduler states
    - huggingface tokenizer/processor and config for ckpt merge
    """

    def __init__(
        self,
        config,
        model_config,
        role,
        model: torch.nn.ModuleList,
        arch: str,
        hf_config,
        param_dtype: torch.dtype,
        share_embeddings_and_output_weights: bool,
        tokenizer,
        optimizer,
        use_distributed_optimizer: bool,
        checkpoint_contents: Optional[list] = None,
        **kwargs,
    ):
        # Avoid a mutable default argument; by default save everything.
        if checkpoint_contents is None:
            checkpoint_contents = ["model", "optimizer", "extra"]
        super().__init__(
            model,
            optimizer=optimizer,
            lr_scheduler=None,
            processing_class=tokenizer,
            checkpoint_contents=checkpoint_contents,
        )
        self.arch = arch
        self.config = config
        self.role = role
        # Value models (reward/critic) use a score head instead of an lm_head;
        # this flag is forwarded to the HF weight saver below.
        self.is_value_model = self.role in ["reward", "critic"]
        self.model_config = model_config
        self.hf_config = hf_config
        self.param_dtype = param_dtype
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.model_path = self.config.model.path
        self.use_distributed_optimizer = use_distributed_optimizer

        self.rank = torch.distributed.get_rank()

        # Architecture-specific function that converts Megatron sharded
        # weights into a HuggingFace-format state dict.
        self.weight_saver = get_weight_saver(self.arch)

    def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init: bool = False):
        """Collect RNG state across data parallel ranks.

        Returns a list of per-DP-rank rng-state dicts, or a ``ShardedObject``
        wrapping that list when ``use_dist_ckpt`` is set.
        """
        rng_state = {
            "random_rng_state": random.getstate(),
            "np_rng_state": np.random.get_state(),
            "torch_rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state(),
            "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(),
        }

        # Gather one state per data-parallel rank only when DP ranks were
        # seeded differently; otherwise a single copy suffices.
        if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init:
            rng_state_list = [None] * mpu.get_data_parallel_world_size()
            torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group())
        else:
            rng_state_list = [rng_state]

        if use_dist_ckpt:
            # Shard the rng state over the (pp, tp, cp) grid; replicas are
            # distinguished by the DP rank (with context parallel folded in).
            pp_rank = mpu.get_pipeline_model_parallel_rank()
            pp_size = mpu.get_pipeline_model_parallel_world_size()
            tp_rank = mpu.get_tensor_model_parallel_rank()
            tp_size = mpu.get_tensor_model_parallel_world_size()
            cp_rank = mpu.get_context_parallel_rank()
            cp_size = mpu.get_context_parallel_world_size()
            rng_state_list = ShardedObject(
                "rng_state",
                rng_state_list,
                (pp_size, tp_size, cp_size),
                (pp_rank, tp_rank, cp_rank),
                replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
            )

        return rng_state_list

    def get_checkpoint_name(
        self,
        checkpoints_path,
        pipeline_parallel=None,
        tensor_rank=None,
        pipeline_rank=None,
        cp_rank=None,
        expert_parallel=None,
        expert_rank=None,
        return_base_dir=True,
        basename="model.pt",
    ):
        """Determine the directory name for this rank's checkpoint.

        Any parallel-rank argument left as ``None`` is filled in from the
        current process's mpu state. Creates the directory as a side effect.
        Returns the directory path, or ``<dir>/<basename>`` when
        ``return_base_dir`` is False.
        """
        # Use both the tensor and pipeline MP rank.
        if pipeline_parallel is None:
            pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1
        if tensor_rank is None:
            tensor_rank = mpu.get_tensor_model_parallel_rank()
        if pipeline_rank is None:
            pipeline_rank = mpu.get_pipeline_model_parallel_rank()
        if cp_rank is None:
            cp_rank = mpu.get_context_parallel_rank()
        if expert_parallel is None:
            expert_parallel = mpu.get_expert_model_parallel_world_size() > 1
        if expert_rank is None:
            expert_rank = mpu.get_expert_model_parallel_rank()

        # Use both the tensor and pipeline MP rank. If using the distributed
        # optimizer, then the optimizer's path must additionally include the
        # data parallel rank.

        # due to the fact that models are identical across cp ranks, cp rank is not used in the checkpoint path
        if not pipeline_parallel:
            common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}")
        else:
            common_path = os.path.join(checkpoints_path, f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}")

        if expert_parallel:
            common_path = common_path + f"_{expert_rank:03d}"

        os.makedirs(common_path, exist_ok=True)

        if return_base_dir:
            return common_path
        return os.path.join(common_path, basename)

    def load_optimizer(self, ckpt_path):
        """Load the (distributed) optimizer parameter state from ``ckpt_path``."""
        # TODO: Check Optimizer format and distributed optimizer
        optimizer_path = get_optimizer_checkpoint_path(ckpt_path)
        print(f"Loading optimizer from {optimizer_path}")
        self.optimizer.load_parameter_state(optimizer_path)

    def load_rng_states(self, ckpt_path, data_parallel_random_init=False, use_dist_ckpt=False):
        """Restore python/numpy/torch/cuda RNG states saved by ``save_checkpoint``."""
        rng_state_path = get_rng_states_checkpoint_path(ckpt_path, only_rank0_save=False)
        print(f"Loading rng states from {rng_state_path}")
        rng_state = torch.load(rng_state_path, weights_only=False)
        # access rng_state for data parallel rank
        if not use_dist_ckpt:
            rng_state = rng_state[mpu.get_data_parallel_rank()] if data_parallel_random_init else rng_state[0]
        random.setstate(rng_state["random_rng_state"])
        np.random.set_state(rng_state["np_rng_state"])
        torch.set_rng_state(rng_state["torch_rng_state"])
        torch.cuda.set_rng_state(rng_state["cuda_rng_state"])
        # An empty tracker-state dict means the checkpoint is unusable for
        # reproducible resumption; fail loudly with a diagnosable message.
        if not rng_state["rng_tracker_states"]:
            raise KeyError("rng_tracker_states is empty in the loaded rng state")
        tensor_parallel.get_cuda_rng_tracker().set_states(rng_state["rng_tracker_states"])

    def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_after_load=False):
        """Load model / optimizer / rng states (per ``checkpoint_contents``) from ``local_path``."""
        if local_path is None:
            return

        if "model" in self.checkpoint_contents:
            model_path = get_model_checkpoint_path(local_path)
            ckpt_name = self.get_checkpoint_name(model_path, return_base_dir=False)
            state_dicts = torch.load(ckpt_name, weights_only=False)
            # One state dict per virtual-pipeline (vpp) model chunk.
            assert len(state_dicts) == len(self.model), f"state_dicts length: {len(state_dicts)} mismatch with model length: {len(self.model)}"
            for state_dict, model in zip(state_dicts, self.model):
                model.load_state_dict(state_dict)
            print(f"Loaded sharded model checkpoint from {model_path}")

        if "optimizer" in self.checkpoint_contents:
            self.load_optimizer(local_path)

        if "extra" in self.checkpoint_contents:
            self.load_rng_states(local_path)

        if del_local_after_load:
            try:
                # Only delete when the path is a remote-backed copy.
                if is_non_local(local_path):
                    os.remove(local_path)
            except Exception as e:
                print(f"[rank-{self.rank}]: remove local resume ckpt file after loading failed, exception {e} will be ignored")

    def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: int = 0, max_ckpt_to_keep=None):
        """Save model / hf_model / optimizer / rng states (per ``checkpoint_contents``).

        Rotates out the oldest local checkpoints when ``max_ckpt_to_keep`` is a
        positive int, and optionally mirrors the saved artifacts to ``hdfs_path``.
        """
        # record the previous global step
        self.previous_global_step = global_step

        # remove previous local_path, keeping at most max_ckpt_to_keep - 1
        # old checkpoints (the one being written now counts as the newest).
        if max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 and len(self.previous_saved_paths) >= max_ckpt_to_keep:
            keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1
            self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start])
            self.previous_saved_paths = self.previous_saved_paths[keep_start:]

        local_path = self.local_mkdir(local_path)

        # Save Model. Model shards are identical across DP ranks, so only
        # DP rank 0 writes them.
        if "model" in self.checkpoint_contents and mpu.get_data_parallel_rank() == 0:
            # One state dict per virtual-pipeline (vpp) model chunk.
            state_dicts = [model.state_dict() for model in self.model]

            print(f"Saving sharded model checkpoint to {local_path}")
            model_ckpt_path = get_model_checkpoint_path(local_path)
            hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
            ckpt_name = self.get_checkpoint_name(model_ckpt_path, return_base_dir=False)
            torch.save(state_dicts, ckpt_name)
            self.processing_class.save_pretrained(hf_model_ckpt_path)  # tokenizer will be saved to hf_model_ckpt_path
            print(f"Saved checkpoint to {model_ckpt_path}")
            if hdfs_path is not None:
                print(f"Uploading checkpoint to {hdfs_path}")
                from verl.utils import hdfs_io

                hdfs_io.makedirs(hdfs_path, exist_ok=True)
                hdfs_io.copy(src=model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)

        if "hf_model" in self.checkpoint_contents:
            # All ranks participate in gathering the full HF-format weights;
            # only rank 0 writes them out below.
            state_dict = self.weight_saver(
                self.model,
                self.hf_config,
                dtype=self.param_dtype,
                is_value_model=self.is_value_model,
                tie_word_embeddings=self.share_embeddings_and_output_weights,
            )

            # wait for everyone to dump to local
            torch.distributed.barrier()
            print(f"self.param_dtype: {self.param_dtype}")
            for key in state_dict.keys():
                print(f"state_dict[key].dtype: {key} {state_dict[key].dtype}")
            torch.distributed.barrier()
            if self.rank == 0:
                hf_model_ckpt_path = get_hf_model_checkpoint_path(local_path)
                import warnings

                from accelerate import init_empty_weights

                # Instantiate the HF model skeleton without allocating real
                # weights; the gathered state_dict supplies them at save time.
                with init_empty_weights(), warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    if "mistral7b-rm" in self.config.model.path:
                        from transformers import MistralForSequenceClassification

                        model = MistralForSequenceClassification.from_pretrained(self.config.model.path)  # use score head instead of lm_head
                    else:
                        from transformers import AutoModelForCausalLM

                        model = AutoModelForCausalLM.from_pretrained(self.config.model.path, torch_dtype="auto")
                model.save_pretrained(hf_model_ckpt_path, state_dict=state_dict)
                if hdfs_path is not None:
                    print(f"Uploading checkpoint to {hdfs_path}")
                    from verl.utils import hdfs_io

                    hdfs_io.makedirs(hdfs_path, exist_ok=True)
                    hdfs_io.copy(src=hf_model_ckpt_path, dst=hdfs_path, dirs_exist_ok=True)

        # Save Optimizer
        if "optimizer" in self.checkpoint_contents:
            torch.distributed.barrier()

            optimizer_path = get_optimizer_checkpoint_path(local_path)
            self.optimizer.save_parameter_state(optimizer_path)
            if self.rank == 0:
                print(f"saving optimizer state to {optimizer_path}")

        # Save RNG States. Every rank saves its own rng state file.
        if "extra" in self.checkpoint_contents:
            torch.distributed.barrier()

            rng_state_path = get_rng_states_checkpoint_path(local_path, only_rank0_save=False)
            rng_state = self.get_rng_state()
            torch.save(rng_state, rng_state_path)
            print(f"Rank {self.rank} saving rng states to {rng_state_path}")

        self.previous_saved_paths.append(local_path)