| # Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import shutil | |
| from pathlib import Path | |
def _list_step_checkpoints(output_dir):
    """Return the ``checkpoint-<step>`` subdirectory names of *output_dir*, oldest step first.

    Only directories whose name is exactly ``checkpoint-`` followed by a decimal
    integer are returned. Other entries are ignored, so rotation can neither
    crash on them nor try to ``rmtree`` a plain file. Examples of ignored
    entries are ``checkpoint-best``, ``checkpoints.txt`` and regular files.
    A missing (or non-directory) *output_dir* yields an empty list.
    """
    root = Path(output_dir)
    if not root.is_dir():
        # Nothing has been saved yet, so there is nothing to rotate.
        return []
    found = []
    for entry in root.iterdir():
        prefix, sep, step = entry.name.partition("-")
        if prefix == "checkpoint" and sep and step.isdecimal() and entry.is_dir():
            found.append((int(step), entry.name))
    # Sort numerically so that e.g. checkpoint-10 comes after checkpoint-9.
    found.sort()
    return [name for _, name in found]


def save_checkpoint(args, accelerator, global_step, logger):
    """Save the accelerator state to ``<args.output_dir>/checkpoint-<global_step>``.

    When ``args.checkpoints_total_limit`` is not ``None``, the main process first
    deletes the oldest ``checkpoint-<step>`` directories in ``args.output_dir``.
    After this save, at most ``checkpoints_total_limit`` step checkpoints remain.
    ``accelerator.save_state`` itself is called on every process.

    Args:
        args: Object providing ``output_dir`` and ``checkpoints_total_limit``.
            A limit of ``None`` disables rotation.
        accelerator: Object exposing ``is_main_process`` and ``save_state(path)``.
            Presumably an ``accelerate.Accelerator``; confirm against callers.
        global_step: Training step used to name the new checkpoint directory.
        logger: Logger that receives progress messages.
    """
    output_dir = args.output_dir

    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
    if accelerator.is_main_process and args.checkpoints_total_limit is not None:
        checkpoints = _list_step_checkpoints(output_dir)

        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
        if len(checkpoints) >= args.checkpoints_total_limit:
            num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
            removing_checkpoints = checkpoints[:num_to_remove]

            logger.info(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                shutil.rmtree(os.path.join(output_dir, removing_checkpoint))

    save_path = Path(output_dir) / f"checkpoint-{global_step}"
    accelerator.save_state(save_path)
    logger.info(f"Saved state to {save_path}")