| import os |
| import glob |
| import re |
| import shutil |
| from torchtitan.tools.logging import logger |
|
|
|
|
# Exact directory-name patterns for the two checkpoint flavours. They are applied
# with fullmatch(), so look-alikes such as "step-final" or "step-7-backup" are
# never treated as checkpoints and therefore never deleted.
_DCP_DIR_RE = re.compile(r"step-(\d+)")
_HF_DIR_RE = re.compile(r"step-(\d+)-hf")


def _prune_checkpoint_dirs(checkpoint_dir, name_re, label, keep_latest_k):
    """Delete all but the ``keep_latest_k`` highest-step directories matching ``name_re``.

    Args:
        checkpoint_dir: Directory that contains the ``step-*`` checkpoint folders.
        name_re: Compiled pattern whose first group is the step number; the
            whole entry name must match it.
        label: Short format name ("DCP" or "HF") used in the log message.
        keep_latest_k: Number of most recent checkpoints to keep (must be > 0).
    """
    candidates = []
    # glob.escape() stops characters like "[" or "*" inside checkpoint_dir from
    # being interpreted as wildcards, which would make the search find nothing.
    search_pattern = os.path.join(glob.escape(checkpoint_dir), "step-*")
    for path in glob.glob(search_pattern):
        match = name_re.fullmatch(os.path.basename(path))
        # Only real directories take part in the ranking, so a stray file with
        # a checkpoint-like name cannot occupy one of the "keep" slots.
        if match is not None and os.path.isdir(path):
            candidates.append((int(match.group(1)), path))

    # Newest first, ordered by the numeric step (step-10 ranks above step-2).
    candidates.sort(reverse=True)
    checkpoints_to_delete = [path for _, path in candidates[keep_latest_k:]]
    if not checkpoints_to_delete:
        return

    logger.info(f"Deleting {len(checkpoints_to_delete)} old {label} checkpoints: {[os.path.basename(c) for c in checkpoints_to_delete]}")
    for ckpt_path in checkpoints_to_delete:
        try:
            shutil.rmtree(ckpt_path)
        except OSError as e:
            # Best effort: a failed removal is reported and the rest still run.
            logger.error(f"Error removing directory {ckpt_path}: {e}")


def cleanup_local_checkpoints(checkpoint_dir: str, keep_latest_k: int):
    """Remove older local checkpoint directories, keeping only the latest k per format.

    DCP checkpoints (directories named ``step-<N>``) and HF checkpoints
    (directories named ``step-<N>-hf``) are pruned independently, so up to
    ``keep_latest_k`` of each kind remain. "Latest" is decided by the numeric
    step in the directory name. Entries whose names do not match one of these
    two patterns exactly, and entries that are not directories, are left alone.

    Args:
        checkpoint_dir: Directory that contains the checkpoint folders.
        keep_latest_k: How many checkpoints of each format to keep. A value
            <= 0 disables cleanup entirely.

    Deletion failures (``OSError``) are logged and do not raise.
    """
    if keep_latest_k <= 0:
        return

    logger.info(f"Cleaning up local checkpoints in {checkpoint_dir}, keeping latest {keep_latest_k}")

    _prune_checkpoint_dirs(checkpoint_dir, _DCP_DIR_RE, "DCP", keep_latest_k)
    _prune_checkpoint_dirs(checkpoint_dir, _HF_DIR_RE, "HF", keep_latest_k)
|
|