File size: 39,662 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import shutil
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Iterable, List, Literal, Optional, Union

import lightning
import lightning.pytorch as pl
import torch
from _weakref import proxy
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint
from lightning.pytorch.callbacks.model_checkpoint import _is_local_file_protocol
from lightning.pytorch.utilities import rank_zero_info

from nemo.lightning.ckpt_utils import ckpt_to_dir
from nemo.lightning.io.pl import TrainerContext
from nemo.utils import logging
from nemo.utils.app_state import AppState


class ModelCheckpoint(PTLModelCheckpoint):
    """Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end.
    Adds support for asyncronous checkpointing and provides some additional logic to clean up invalid checkpoints

    Args:
        monitor: Metric to monitor when saving top-k checkpoints.
        verbose: Verbosity mode.
        save_last: When ``True``, saves a `*-last` copy whenever a checkpoint file gets saved.
        save_top_k: When ``True``, saves the top-k checkpoints according to ``monitor``.
        save_weights_only:  if ``True``, then only the model's weights will be saved. Optimizer states will
            be omitted from all checkpoints.
        mode: One of {min, max}. Whether the objective is to minimize or maximize the monitored quantity.
        every_n_epochs: Number of epochs between checkpoints.
        every_n_train_steps: Number of train steps between checkpoints.
        train_time_interval: After each interval, monitor checkpoints. Not to be used with
            ``every_n_epochs`` or ``every_n_train_steps``.
        save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch
        save_optim_on_train_end: Whether to include the optimizer states in the final checkpoint
            at the end of training. Only applicable when save_weights_only is ``False``.
        always_save_context: Whether to dump the artifacts needed to reinintialize the current
            model, trainer, and dataloader to allow for reproducibility of experiments.
        save_context_on_train_end: Whether to dump the artifacts on_train_end regardless of whether
            ``always_save_context`` is ``True``.
        async_save: Whether to enable asynchronous checkpointing.

    Attributes:
        UNFINISHED_CHECKPOINT_SUFFIX (str): Suffix for unfinished checkpoint files.
        deferred_ckpts_to_remove (List[List[str]]): List of deferred checkpoints
            to remove once async save is completed.
        ckpts_to_link (Dict[str, str]): Dictionary of checkpoint paths that need to be symlinked.
        future_last_model_path (str): Path to the future 'last' checkpoint, used for symbolic linking.
        best_k_models (dict): Dictionary of best-k checkpoints based on the monitored metric.
        best_model_score (float): Score of the best checkpoint.
        best_model_path (str): Path to the best checkpoint.
        kth_best_model_path (str): Path to the kth best checkpoint.
    """

    UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"

    def __init__(
        self,
        monitor: Optional[str] = "val_loss",
        verbose: bool = True,
        save_last: Optional[Union[bool, Literal["link"]]] = True,
        save_top_k: int = 3,
        save_weights_only: bool = False,  # TODO: check support
        mode: str = "min",
        every_n_epochs: Optional[int] = None,
        every_n_train_steps: Optional[int] = None,
        train_time_interval: Optional[timedelta] = None,
        # Save after training, not after validation
        save_on_train_epoch_end: Optional[bool] = False,
        save_optim_on_train_end: Optional[bool] = False,
        always_save_context: bool = True,
        save_context_on_train_end: bool = True,
        **kwargs,
    ):
        """Initialize NeMo-specific state, then delegate to Lightning's ModelCheckpoint.

        See the class docstring for the meaning of each argument. NeMo-only options
        (``save_optim_on_train_end``, ``always_save_context``,
        ``save_context_on_train_end``) are consumed here; everything else, plus
        ``**kwargs``, is forwarded verbatim to the parent constructor.
        """
        self.always_save_context = always_save_context
        self.save_context_on_train_end = save_context_on_train_end
        self.save_optim_on_train_end = save_optim_on_train_end

        # stores the next -last checkpoint to be saved, used only when save_last = 'link'
        # this is needed because when using symlinks, we need to update the non-last checkpoint's
        # last_model_path to point to the corresponding -last version
        self.future_last_model_path = ""

        # Checkpoints which removal is deferred until async save is done.
        # Each element of `deferred_ckpts_to_remove` is a growing list
        # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
        # is called, the last element is frozen and a new element is added.
        self.deferred_ckpts_to_remove: List[List[str]] = []
        self.ckpts_to_link: Dict[str, str] = {}

        # Call the parent class constructor with the remaining kwargs.
        super().__init__(
            monitor=monitor,
            verbose=verbose,
            save_last=save_last,
            save_top_k=save_top_k,
            save_weights_only=save_weights_only,
            mode=mode,
            every_n_epochs=every_n_epochs,
            every_n_train_steps=every_n_train_steps,
            train_time_interval=train_time_interval,
            save_on_train_epoch_end=save_on_train_epoch_end,
            **kwargs,
        )

    def on_train_start(self, trainer, pl_module):
        """
        Initializes checkpointing by handling previous runs,
        setting up file logging, and managing files to move or copy.

        This method handles:
        - Moving old files to new folders
        - Copying relevant files to the log directory
        - Creating command argument and git information logs
        - Setting up logging for errors and Lightning logs

        All filesystem work happens on global rank zero only; other ranks wait
        at the distributed barrier at the end before training proceeds.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning trainer object.
            pl_module (pl.LightningModule): The Lightning model to be trained.
        """
        from nemo.utils.exp_manager import get_git_diff, get_git_hash
        from nemo.utils.get_rank import is_global_rank_zero
        from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger

        app_state = AppState()
        # On restore, reconcile top-k bookkeeping with checkpoints left by the
        # previous run (skipped when save_top_k == -1, i.e. keep everything).
        if self.save_top_k != -1 and app_state.restore:
            logging.debug("Checking previous runs")
            self.nemo_topk_check_previous_run()

        if is_global_rank_zero():
            log_dir = app_state.log_dir

            # Check to see if any files exist that need to be moved
            files_to_move = app_state.files_to_move

            if len(files_to_move) > 0:
                # Move old files to a new folder
                # Count existing run_* directories so the new folder gets the next index.
                other_run_dirs = Path(log_dir).glob("run_*")
                run_count = 0
                for fold in other_run_dirs:
                    if fold.is_dir():
                        run_count += 1
                new_run_dir = Path(Path(log_dir) / f"run_{run_count}")
                if not new_run_dir.exists():
                    new_run_dir.mkdir()
                    for _file in files_to_move:
                        shutil.move(str(_file), str(new_run_dir))

            # Move files_to_copy to folder and add git information if present
            if app_state.files_to_copy:
                for _file in app_state.files_to_copy:
                    src_path = Path(_file)
                    dst_path = Path(log_dir) / src_path.name
                    # Never overwrite a file already present in the log dir.
                    if not dst_path.exists():
                        shutil.copy(src_path, dst_path)

            # Create files for cmd args and git info
            if app_state.cmd_args:
                cmd_args_file = log_dir / 'cmd-args.log'
                if not cmd_args_file.exists():
                    with open(cmd_args_file, 'w', encoding='utf-8') as _file:
                        _file.write(" ".join(app_state.cmd_args))

            # Try to get git hash
            git_repo, git_hash = get_git_hash()
            if git_repo:
                git_info_file = log_dir / 'git-info.log'
                if not git_info_file.exists():
                    with open(git_info_file, 'w', encoding='utf-8') as _file:
                        _file.write(f'commit hash: {git_hash}\n')
                        _file.write(get_git_diff())

            # Add err_file logging to global_rank zero
            logging.add_err_file_handler(log_dir / 'nemo_error_log.txt')

            # Add lightning file logging to global_rank zero
            add_filehandlers_to_pl_logger(log_dir / 'lightning_logs.txt', log_dir / 'nemo_error_log.txt')
        # Hold all ranks until rank zero has finished the filesystem setup above.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        super().on_train_start(trainer, pl_module)

    def nemo_topk_check_previous_run(self):
        """
        Verifies and cleans up the top-k checkpoint state from previous training runs.

        This method ensures that:
        - The top-k models are correctly loaded and ordered.
        - Any outdated or invalid checkpoints are removed.
        - The best model is determined based on the monitored metric.

        The monitored metric's value is parsed out of each checkpoint filename:
        everything between '<monitor>=' and the next alphabetic character.

        Raises:
            AttributeError: If the expected attributes for the top-k model are not found.
        """
        # Fail fast if Lightning renamed the attributes this logic relies on.
        for attr in ("best_k_models", "kth_best_model_path", "best_model_score", "best_model_path"):
            if not hasattr(self, attr):
                raise AttributeError(
                    "Lightning's ModelCheckpoint was updated. NeMo's ModelCheckpoint will need an update."
                )
        self.best_k_models = {}
        self.kth_best_model_path = ""
        self.best_model_score = None
        self.best_model_path = ""

        checkpoints = list(path for path in self._saved_checkpoint_paths if not self._is_ema_filepath(path))
        for checkpoint in checkpoints:
            checkpoint = str(checkpoint)
            # '-last' checkpoints are not ranked by the monitored metric.
            if checkpoint[-10:] == '-last.ckpt' or checkpoint[-5:] == '-last':
                continue
            # Find monitor in str + 1 for '='
            index = checkpoint.find(self.monitor) + len(self.monitor) + 1
            # find() returns -1 when the monitor is absent, making index == len(self.monitor).
            if index != len(self.monitor):
                # The metric value runs until the first alphabetic character.
                # BUGFIX: '[A-z]' also matched '[', '\\', ']', '^', '_' and '`'; use an explicit class.
                match = re.search(r'[A-Za-z]', checkpoint[index:])
                if match:
                    # -1 due to separator hyphen
                    value = checkpoint[index : index + match.start() - 1]
                else:
                    value = checkpoint[index:]
                self.best_k_models[checkpoint] = float(value)
        if len(self.best_k_models) < 1:
            return  # No saved checkpoints yet

        _reverse = False if self.mode == "min" else True

        best_k_models = sorted(self.best_k_models, key=self.best_k_models.get, reverse=_reverse)

        # This section should be ok as rank zero will delete all excess checkpoints, since all other ranks are
        # instantiated after rank zero. models_to_delete should be 0 for all other ranks.
        models_to_delete = len(best_k_models) - self.save_top_k
        models_to_delete = max(0, models_to_delete)
        logging.debug(f'Number of models to delete: {models_to_delete}')

        # If EMA enabled, delete the additional EMA weights
        ema_enabled = self._has_ema_ckpts(self._saved_checkpoint_paths)

        for _ in range(models_to_delete):
            model = best_k_models.pop(-1)
            self.best_k_models.pop(model)
            self._del_model_without_trainer(model)
            if ema_enabled and self._fs.exists(self._ema_format_filepath(model)):
                self._del_model_without_trainer(self._ema_format_filepath(model))
            logging.debug(f"Removed checkpoint: {model}")

        self.kth_best_model_path = best_k_models[-1]
        self.best_model_path = best_k_models[0]
        self.best_model_score = self.best_k_models[self.best_model_path]

    def _remove_invalid_entries_from_topk(self):
        """Drop top-k entries whose checkpoints are missing or left unfinished.

        Needed when a previous run was abruptly terminated mid-save: such
        checkpoints either have no on-disk directory or still carry an
        'unfinished' marker. After filtering, the best/kth bookkeeping
        (`best_k_models`, `best_model_path`, `best_model_score`,
        `kth_best_model_path`, `kth_value`) is recomputed from what remains.
        """

        def _valid(ckpt: str) -> bool:
            # A dist-ckpt lives in a directory named after the path minus '.ckpt';
            # it must exist and must not be flagged as unfinished.
            if not os.path.isdir(ckpt.removesuffix('.ckpt')):
                return False
            return not self.is_checkpoint_unfinished(ckpt)

        surviving = {path: score for path, score in self.best_k_models.items() if _valid(path)}
        self.best_k_models = surviving

        if not surviving:
            self.kth_best_model_path = ""
            self.kth_value = None
            self.best_model_path = ""
            self.best_model_score = None
            return

        ordered = sorted(surviving, key=surviving.get, reverse=(self.mode != "min"))
        self.kth_best_model_path = ordered[-1]
        self.kth_value = surviving[self.kth_best_model_path]
        self.best_model_path = ordered[0]
        self.best_model_score = surviving[self.best_model_path]

    def state_dict(self):
        """Return the callback state, patching ``last_model_path`` in symlink mode.

        When ``save_last == "link"`` the '-last' checkpoint is created as a
        symlink after the underlying save, so the parent's recorded
        ``last_model_path`` can lag by one step; report the precomputed future
        path instead to avoid that off-by-one on resume.

        Returns:
            Dict[str, Any]: The (possibly patched) callback state.
        """
        state = super().state_dict()
        if self.save_last != "link":
            return state
        state["last_model_path"] = self.future_last_model_path
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore callback state, then prune stale top-k entries.

        A resumed run may reference checkpoints that were deleted or left
        half-written by a crash; those are removed from the top-k bookkeeping
        right after the parent state is loaded.

        Args:
            state_dict: State previously produced by :meth:`state_dict`.
        """
        super().load_state_dict(state_dict)
        self._remove_invalid_entries_from_topk()

    def setup(self, trainer, *args, **kwargs) -> None:
        """Remove leftover unfinished checkpoints, then run the parent setup.

        Rank zero performs the cleanup; a distributed barrier keeps every other
        rank from proceeding until the stale checkpoints are gone. Also records
        whether the strategy requested asynchronous checkpointing.

        Args:
            trainer: The trainer instance used for training.
            *args: Forwarded to the parent ``setup``.
            **kwargs: Forwarded to the parent ``setup``.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        if is_global_rank_zero():
            logging.debug("Removing unfinished checkpoints if any...")
            ModelCheckpoint._remove_unfinished_checkpoints(self.dirpath)
        if torch.distributed.is_initialized():
            # Ensure that all ranks continue with unfinished checkpoints removed
            torch.distributed.barrier()

        self.async_save = getattr(trainer.strategy, "async_save", False)
        super().setup(trainer, *args, **kwargs)

    def on_train_end(self, trainer, pl_module):
        """
        Handles actions to be performed when training ends, such as saving the last checkpoint.

        This method ensures that the last checkpoint is saved if needed, particularly when validation steps
        aren't always run based on the interval. It also manages saving the training context to disk, if configured.

        Args:
            trainer: The trainer instance used for training.
            pl_module: The model being trained.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        if trainer.fast_dev_run:
            return None

        # check if we need to save a last checkpoint manually as validation isn't always run based on the interval
        if self.save_last and trainer.val_check_interval != 0:
            should_save_last_checkpoint = False
            # NOTE(review): the float branch computes 'val_check_interval % global_step', which looks
            # reversed relative to the int branch ('global_step % val_check_interval') — confirm intent.
            if isinstance(trainer.val_check_interval, float) and trainer.val_check_interval % trainer.global_step != 0:
                should_save_last_checkpoint = True
            if isinstance(trainer.val_check_interval, int) and trainer.global_step % trainer.val_check_interval != 0:
                should_save_last_checkpoint = True
            if should_save_last_checkpoint:
                monitor_candidates = self._monitor_candidates(trainer)
                # Skip the save when the '-last' checkpoint for this step already exists.
                if self.last_model_path == self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST):
                    logging.debug(f'Last checkpoint {self.last_model_path} already saved')
                else:
                    super()._save_last_checkpoint(trainer, monitor_candidates)
            # Dump the reproducibility context once at the end when it was not saved with every checkpoint.
            if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero():
                TrainerContext.from_trainer(trainer).io_dump(
                    ckpt_to_dir(self.last_model_path) / "context", yaml_attrs=["model"]
                )
        # Call parent on_train_end() to save the -last checkpoint
        super().on_train_end(trainer, pl_module)

    def _del_model_without_trainer(self, filepath: str) -> None:
        """
        Deletes the checkpoint model directory from distributed storage without requiring the trainer.

        Only rank zero performs the removal (best-effort: failures are logged,
        not raised); all ranks then synchronize so none observes a half-deleted
        checkpoint.

        Args:
            filepath (str): The path to the checkpoint model file to be deleted.
        """

        from nemo.utils.get_rank import is_global_rank_zero

        filepath = Path(filepath)

        if is_global_rank_zero():
            # BUGFIX: ckpt_to_dir() is resolved before the try-block. Previously a failure there
            # left `dist_ckpt` unbound, so the bare `except:` handler itself raised NameError.
            dist_ckpt = ckpt_to_dir(filepath)
            try:
                shutil.rmtree(dist_ckpt, ignore_errors=True)
                logging.info(f"Removed distributed checkpoint: {dist_ckpt}")
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
                logging.info(f"Tried to remove distributed checkpoint: {dist_ckpt} but failed.")
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

    def _ema_callback(self, trainer: 'lightning.pytorch.Trainer'):
        """Find the EMA callback attached to the trainer, if any.

        All trainer callbacks are scanned; if several EMA callbacks are attached
        the last one wins, matching the original forward-scan behavior.

        Args:
            trainer ('lightning.pytorch.Trainer'): The trainer whose callbacks are inspected.

        Returns:
            EMA: The EMA callback instance if found, or None if not present.
        """
        from nemo.collections.common.callbacks import EMA

        matches = [cb for cb in trainer.callbacks if isinstance(cb, EMA)]
        return matches[-1] if matches else None

    @staticmethod
    def format_checkpoint_unfinished_marker_path(checkpoint_path: Union[Path, str]) -> Path:
        """Format the path to the unfinished checkpoint marker file.

        If the marker file exists, corresponding checkpoint is considered unfinished/incomplete.
        NOTE: Marker path for the EMA checkpoint part is the same as for the original checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.

        Returns:
            Path to the unfinished checkpoint marker file.
        """
        marker_filepath = str(checkpoint_path).removesuffix(".ckpt")
        marker_filepath = marker_filepath.removesuffix("-EMA")
        return Path(marker_filepath + ModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX)

    @staticmethod
    def is_checkpoint_unfinished(checkpoint_path: Union[Path, str]) -> bool:
        """Check if the checkpoint is unfinished.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.

        Returns:
            True if the checkpoint is unfinished, False otherwise.
        """
        return ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path).exists()

    @staticmethod
    def set_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_after=False) -> None:
        """Mark the given checkpoint as unfinished by creating its marker file.

        Only rank zero writes the marker; parent directories are created as
        needed. The checkpoint itself does not need to exist.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
            barrier_after: When True, synchronize ranks after the marker
              is written. Defaults to False.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        if is_global_rank_zero():
            marker = ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
            marker.parent.mkdir(parents=True, exist_ok=True)
            marker.touch()
        if barrier_after and torch.distributed.is_initialized():
            torch.distributed.barrier()

    @staticmethod
    def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_before=False) -> None:
        """Clear unfinished marker for given checkpoint.

        Removal is best-effort: failures are silently ignored so checkpoint
        finalization never crashes the run. Only ordinary exceptions are
        swallowed, so KeyboardInterrupt/SystemExit still propagate.

        Args:
            checkpoint_path: Path to the checkpoint file or dir.
              Does not need to exist.
            barrier_before: Synchronize ranks before removing the marker file.
              Defaults to False.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        try:
            if barrier_before and torch.distributed.is_initialized():
                torch.distributed.barrier()
            if is_global_rank_zero():
                marker_path = ModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
                if marker_path.exists():
                    marker_path.unlink()
        except Exception:
            # BUGFIX: narrowed from a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit; best-effort semantics are preserved.
            return

    def file_exists(self, filepath: str, trainer: "lightning.pytorch.Trainer", check_dist_ckpt: bool = True) -> bool:
        """Checks if a file or a file without a suffix (distributed checkpoint) exists.

        The local result is broadcast through the strategy so every rank agrees.
        """
        found = self._fs.exists(filepath)
        if not found and check_dist_ckpt:
            # Distributed checkpoints are directories named without the '.ckpt' suffix.
            found = self._fs.exists(str(ckpt_to_dir(filepath)))
        return trainer.strategy.broadcast(found)

    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, torch.Tensor]:
        """Broadcast loss from last pipeline stage.

        With pipeline parallelism the training loss only exists on the last
        stage; other ranks get a zero placeholder which a MegatronStrategy then
        overwrites via broadcast so checkpoint filenames agree across ranks.

        Args:
            trainer (pl.Trainer): Trainer whose logged metrics are collected.

        Returns:
            Dict[str, torch.Tensor]: Metric candidates usable in checkpoint names.
        """
        monitor_candidates = super()._monitor_candidates(trainer)

        from nemo.lightning._strategy_lib import _sync_from_last_pipeline_stage
        from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

        # Metric names referenced by the filename template, e.g. '{val_loss:.2f}' -> 'val_loss'.
        keys = re.findall(r"[\{](.*?)[:\}]", self.filename)
        for loss_name in ['reduced_train_loss']:
            if loss_name in keys or loss_name == self.monitor:
                if loss_name not in monitor_candidates:
                    # Placeholder for ranks that did not compute the loss (non-last pipeline stages).
                    monitor_candidates[loss_name] = torch.tensor(0.0, device=torch.cuda.current_device())
                if isinstance(trainer.strategy, MegatronStrategy):
                    _sync_from_last_pipeline_stage(monitor_candidates[loss_name], broadcast=True)

        return monitor_candidates

    def _link_checkpoint(self, trainer: "pl.Trainer", filepath: str, linkpath: str, override_async=False) -> None:
        """Create the '-last' checkpoint, as a symlink when possible.

        A symlink is only valid when this step's checkpoint was already saved
        as a top-k checkpoint; otherwise a real save of ``linkpath`` is done.
        Under async save (unless overridden) linking is deferred to the
        finalization callback via ``ckpts_to_link``.
        """
        target_dir = str(ckpt_to_dir(filepath))
        link_dir_stripped = str(ckpt_to_dir(linkpath)).replace("-last", "")
        if link_dir_stripped != target_dir:
            # Current step was not saved as top-k: fall back to a full save.
            self._save_checkpoint(trainer, linkpath)
            return

        if self.async_save and not override_async:
            # linking will happen as part of the finalize fn
            self.ckpts_to_link[str(filepath)] = str(linkpath)
            return

        super()._link_checkpoint(trainer, ckpt_to_dir(filepath), ckpt_to_dir(linkpath))

    def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str) -> None:
        """Saves the checkpoint to the given filepath

        Handles three save modes visible below: EMA (sync only), async
        (finalization deferred to the checkpoint IO), and plain sync save.

        Args:
            trainer (lightning.pytorch.Trainer): the trainer obj
            filepath (str): path to save checkpoint to.

        Raises:
            ValueError: (mcore) async_save with EMA not supported
            ValueError: (mcore) Async save requires async compatible CheckpointIO
        """

        from nemo.utils.get_rank import is_global_rank_zero

        # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
        # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
        self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
        ema_callback = self._ema_callback(trainer)

        # Record the step of this save attempt immediately, even for async saves
        # that finalize later.
        self._last_global_step_saved = trainer.global_step

        # manually update last_model_path so symlink is up-to-date
        # should only be done when using a symlink
        if self.save_last == "link":
            self.future_last_model_path = str(ckpt_to_dir(filepath))
            if not str(ckpt_to_dir(filepath)).endswith("last"):
                self.future_last_model_path += "-last.ckpt"

        if ema_callback is not None:
            if self.async_save:
                raise ValueError('async_save with EMA not supported')
            # Save the regular weights with the original (non-EMA) optimizer state.
            with ema_callback.save_original_optimizer_state(trainer):
                super()._save_checkpoint(trainer, filepath)

            # save EMA copy of the model as well.
            with ema_callback.save_ema_model(trainer):
                # NOTE(review): this unconditional log fires before `filepath` is
                # EMA-formatted and duplicates the verbose log two lines below —
                # looks unintentional; confirm.
                rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
                filepath = self._ema_format_filepath(filepath)
                if self.verbose:
                    rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
                super()._save_checkpoint(trainer, filepath)
            # NOTE(review): `filepath` was reassigned to the EMA path above, so the
            # marker removed here is the EMA one, while the marker set at the top of
            # this method was for the original path — verify the original marker is
            # cleared elsewhere.
            self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
        else:
            # Determine whether to include optimizer states in the checkpoint
            # optimizer states are included when
            # 1. save_weights_only is False and
            # 2. either save_optim_on_train_end is True, or save_optim_on_train_end is False but the checkpoint
            #    is an intermediate checkpoint.
            save_weights_only = self.save_weights_only or (
                not self.save_optim_on_train_end and trainer.global_step == trainer.max_steps
            )

            # Async save passes the finalization function to checkpoint_io,
            # sync save calls the finalization function immediately after save.
            finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
            if self.async_save:
                checkpoint_io = trainer.strategy.checkpoint_io
                from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO

                if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO):
                    raise ValueError('Async save requires async compatible CheckpointIO')
                storage_options = dict(finalize_fn=finalize_fn)
                # Each upcoming ckpt removal request will be executed as part of this save finalization
                self.deferred_ckpts_to_remove.append([])
            else:
                storage_options = None
            trainer.save_checkpoint(filepath, save_weights_only, storage_options=storage_options)

            # Dump the trainer context (model config etc.) next to the checkpoint,
            # only once per job (global rank 0).
            if self.always_save_context and is_global_rank_zero():
                TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context", yaml_attrs=["model"])

            if self.async_save:
                self._last_checkpoint_saved = filepath
                logging.info(f'Scheduled async checkpoint save for {filepath}')
            else:
                # Sync path: finalize right away (marker removal, logger notify).
                finalize_fn()

    def _get_finalize_save_checkpoint_callback(
        self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int
    ):
        """Creates a callback that can be used to finalize async (and sync) ckpt saves."""

        def _cb():
            logging.debug(f'Finalize callback called for step {global_step}, filepath {filepath}')
            self._last_checkpoint_saved = filepath

            # notify loggers
            if trainer.is_global_zero:
                for logger in trainer.loggers:
                    logger.after_save_checkpoint(proxy(self))

            # barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker
            # we don't want to remove the marker until all checkpointing is done.
            self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)

            if not self.async_save:
                return

            logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.')

            if str(filepath) in self.ckpts_to_link:
                self._link_checkpoint(trainer, filepath, self.ckpts_to_link.pop(filepath), override_async=True)

            # Remove checkpoints marked for removal by `self._remove_checkpoint`
            # For each finalization there is exactly one entry in self.deferred_ckpts_to_remove
            assert self.deferred_ckpts_to_remove
            ckpts_to_remove = self.deferred_ckpts_to_remove.pop(0)
            logging.debug(f'Checkpoints to remove: {ckpts_to_remove}')
            for ckpt_to_remove in ckpts_to_remove:
                self._remove_checkpoint(trainer, ckpt_to_remove, override_async=True)

        return _cb

    def _remove_checkpoint(self, trainer: "lightning.pytorch.Trainer", filepath: str, override_async=False) -> None:
        """Performs checkpoint removal.

        With async save, `self._remove_checkpoint` is called before the checkpoint
        is actually finished so we can't remove it. Instead we add it to
        `self.deferred_ckpts_to_remove` for future removal.
        """
        if self.async_save and not override_async:
            # Register checkpoint removal in the last (active) checkpoint removal list
            if len(self.deferred_ckpts_to_remove) == 0:
                self.deferred_ckpts_to_remove.append([])
            self.deferred_ckpts_to_remove[-1].append(filepath)
            return
        # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
        # if anything goes wrong during removal, we should be able to detect that data is incomplete.
        self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
        try:
            super()._remove_checkpoint(trainer, filepath)
        except Exception as e:
            logging.warning(
                f'Error removing checkpoint, common if doing manual cleanup and restarting: {filepath}: {e}'
            )
        ema_callback = self._ema_callback(trainer)
        if ema_callback is not None:
            # remove EMA copy of the state dict as well.

            filepath = self._ema_format_filepath(filepath)
            try:
                super()._remove_checkpoint(trainer, filepath)
            except Exception as e:
                logging.warning(
                    f'Error removing checkpoint, common if doing manual cleanup and restarting: {filepath}: {e}'
                )
        # barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker
        # we don't want to remove the marker until the checkpoint is actually removed.
        self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)

    def _ema_format_filepath(self, filepath: str) -> str:
        """Formats given path for EMA checkpoint

        Args:
            filepath (str): filepath

        Returns:
            str: EMA-formatted filepath
        """
        return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}')

    def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool:
        """Checkes whether filepaths are EMA-formatted

        Args:
            checkpoints (Iterable[Path]): paths to check

        Returns:
            bool: True indicates path is EMA-formatted.
        """
        return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints)

    def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
        """Checkes whether filepaths are EMA-formatted

        Args:
            filepath (Union[Path, str]): path to check

        Returns:
            bool: True indicates path is EMA-formatted.
        """
        return str(filepath).endswith(f'-EMA{self.FILE_EXTENSION}')

    @property
    def _saved_checkpoint_paths(self) -> Iterable[Path]:
        """
        Retrieves a list of saved checkpoint paths while filtering out unfinished checkpoints.

        - If distributed checkpoints (directories) exist, return only those.
        - Otherwise, return individual checkpoint files with a .ckpt extension.
        - Filters out any checkpoints that are marked as unfinished.

        Returns:
            Iterable[Path]: An iterable containing valid checkpoint paths.
        """
        # distributed checkpoints are directories so we check for them here
        # we filter out unfinished checkpoints, these should be deleted during next cleanup
        dist_checkpoints = [d for d in Path(self.dirpath).glob("*") if d.is_dir()]
        if dist_checkpoints:
            return filter(lambda p: not self.is_checkpoint_unfinished(p), dist_checkpoints)
        else:
            checkpoint_files = [f for f in Path(self.dirpath).rglob("*.ckpt")]
            return filter(lambda p: not self.is_checkpoint_unfinished(p), checkpoint_files)

    @staticmethod
    def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
        """
        Delete every unfinished checkpoint — plain ``.ckpt`` files and
        distributed-checkpoint directories — together with the marker files
        that flagged them as unfinished.

        Args:
            checkpoint_dir (Union[Path, str]): directory holding the checkpoints.

        Raises:
            AssertionError: if invoked from a non-rank-0 process.
        """
        from nemo.utils.get_rank import is_global_rank_zero

        # Filesystem mutation must happen exactly once across the job.
        if not is_global_rank_zero():
            raise AssertionError("_remove_unfinished_checkpoints should run only on rank 0")

        checkpoint_dir = Path(checkpoint_dir)

        # Collect the marker files first; anything they point at is unfinished.
        marker_glob = f"*{ModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}"
        existing_marker_filepaths = {
            marker.resolve() for marker in checkpoint_dir.glob(marker_glob) if marker.is_file()
        }

        def _is_unfinished(path: Path) -> bool:
            # A checkpoint counts as unfinished iff its marker file exists.
            return ModelCheckpoint.format_checkpoint_unfinished_marker_path(path) in existing_marker_filepaths

        # Remove individual .ckpt files that have a matching marker.
        for ckpt_filepath in {f.resolve() for f in checkpoint_dir.rglob("*.ckpt")}:
            if _is_unfinished(ckpt_filepath):
                logging.warning(f'Removing unfinished checkpoint: {ckpt_filepath}')
                os.remove(ckpt_filepath)

        # Distributed checkpoints are directories; remove marked ones wholesale.
        for ckpt_dirpath in {d.resolve() for d in checkpoint_dir.glob("*") if d.is_dir()}:
            if _is_unfinished(ckpt_dirpath):
                logging.warning(f'Removing unfinished dist checkpoint: {ckpt_dirpath}')
                shutil.rmtree(ckpt_dirpath)

        # Finally drop the marker files themselves.
        for marker_path in existing_marker_filepaths:
            os.remove(marker_path)

    def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, current: str) -> bool:
        """Checks if the previous checkpoint should be deleted.
        A checkpoint won't be deleted if any of the cases apply:
        - The previous checkpoint is the same as the current checkpoint (means the old was already overwritten by new)
        - The previous checkpoint is not in the current checkpoint directory and the filesystem is local
        - The previous checkpoint is the checkpoint the Trainer resumed from and the filesystem is local
            and the resumed from checkpoint is not the last checkpoint
        """
        if previous == current:
            return False
        if not _is_local_file_protocol(previous):
            return True
        previous = Path(previous).absolute()
        resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None

        if resume_path is not None and previous == resume_path:
            if str(current).endswith("-last.ckpt") and resume_path.name.endswith("-last.ckpt"):
                # delete the previous `-last.ckpt` checkpoint when current saved checkpoint
                # is also `-last.ckpt`, if they're in the same directory
                pass
            else:
                return False
        if self.dirpath is None:
            raise ValueError(f"{self.__class__}.dirpath is None.")
        dirpath = Path(self.dirpath).absolute()
        return dirpath in previous.parents