# UltraThinking-LLM-Training / scripts/recover_checkpoint.py
# Uploaded via huggingface_hub (Vedisasi, commit 54c5666 verified)
"""Recover from corrupted or incomplete checkpoints"""
import torch
import os
import argparse
from pathlib import Path
import logging
# Script-wide console logging: timestamped INFO-level messages.
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger used by the recovery helpers below.
logger = logging.getLogger(__name__)
def recover_checkpoint(checkpoint_dir: str, output_path: str, verbose: bool = True) -> bool:
    """
    Try to recover a usable checkpoint from a directory of possibly
    corrupted or incomplete checkpoint files.

    Candidate files (*.pt / *.pth) are tried newest-first; the first one
    that loads cleanly and contains 'model_state_dict' is re-saved to
    output_path.

    Args:
        checkpoint_dir: Directory containing checkpoint files
        output_path: Path to save recovered checkpoint
        verbose: Print detailed information

    Returns:
        True if recovery successful, False otherwise
    """
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.exists():
        logger.error(f"Checkpoint directory not found: {checkpoint_dir}")
        return False

    # Find all checkpoint files
    checkpoint_files = list(checkpoint_dir.glob("*.pt")) + list(checkpoint_dir.glob("*.pth"))
    if not checkpoint_files:
        logger.error(f"No checkpoint files found in {checkpoint_dir}")
        return False
    logger.info(f"Found {len(checkpoint_files)} checkpoint files")

    # Try to load checkpoints in reverse order (newest first)
    checkpoint_files = sorted(checkpoint_files, key=os.path.getmtime, reverse=True)

    for ckpt_file in checkpoint_files:
        try:
            if verbose:
                logger.info(f"Attempting to load: {ckpt_file.name}")

            # Full training checkpoints carry optimizer/scheduler state, so the
            # weights-only safe loader cannot be used. weights_only=False is the
            # pre-torch-2.6 default; made explicit so behavior does not change
            # under torch>=2.6. NOTE: unpickling can execute arbitrary code --
            # only run this on checkpoints from trusted sources.
            checkpoint = torch.load(ckpt_file, map_location='cpu', weights_only=False)

            # Validate checkpoint structure
            if not isinstance(checkpoint, dict):
                logger.warning(f" βœ— Not a dictionary: {type(checkpoint)}")
                continue

            required_keys = ['model_state_dict']
            optional_keys = ['optimizer_state_dict', 'scheduler_state_dict', 'epoch', 'global_step', 'loss']
            if not all(k in checkpoint for k in required_keys):
                missing = [k for k in required_keys if k not in checkpoint]
                logger.warning(f" βœ— Missing required keys: {missing}")
                logger.info(f" Available keys: {list(checkpoint.keys())}")
                continue

            # Checkpoint is valid
            logger.info(f" βœ“ Valid checkpoint found: {ckpt_file.name}")

            # Print checkpoint info
            if verbose:
                logger.info(f" Checkpoint information:")
                for key in optional_keys:
                    if key in checkpoint:
                        value = checkpoint[key]
                        if key in ['epoch', 'global_step']:
                            logger.info(f" {key}: {value}")
                        elif key == 'loss':
                            # Guard the float formatting: a non-numeric loss
                            # (e.g. None) must not bubble into the outer except
                            # and cause an otherwise-valid checkpoint to be
                            # rejected.
                            try:
                                logger.info(f" {key}: {value:.6f}")
                            except (TypeError, ValueError):
                                logger.info(f" {key}: {value}")
                # Count model parameters
                model_state = checkpoint['model_state_dict']
                num_params = sum(v.numel() for v in model_state.values())
                logger.info(f" Parameters: {num_params:,}")

            # Save recovered checkpoint
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(checkpoint, output_path)
            logger.info(f" βœ“ Saved recovered checkpoint to: {output_path}")
            return True

        except Exception as e:
            # Corrupted / truncated files land here; try the next candidate.
            logger.warning(f" βœ— Failed to load {ckpt_file.name}: {e}")
            continue

    logger.error("βœ— No valid checkpoint could be recovered")
    return False
def inspect_checkpoint(checkpoint_path: str):
    """Print a human-readable summary of a single checkpoint file."""
    divider = "=" * 60
    # The whole body sits in one try-block: any failure (load or print)
    # is reported with the same error message instead of raising.
    try:
        state = torch.load(checkpoint_path, map_location='cpu')

        print(divider)
        print(f"Checkpoint: {checkpoint_path}")
        print(divider)

        # Anything other than a dict is not a format we understand.
        if not isinstance(state, dict):
            print(f"Type: {type(state)}")
            print("Not a dictionary - unexpected format")
            return

        print(f"Keys: {list(state.keys())}")
        print()

        # Model weights summary
        if 'model_state_dict' in state:
            weights = state['model_state_dict']
            total = sum(t.numel() for t in weights.values())
            print(f"Model parameters: {total:,}")
            print(f"Model state keys: {len(weights)}")

        # Training progress info, printed only when present
        if 'epoch' in state:
            print(f"Epoch: {state['epoch']}")
        if 'global_step' in state:
            print(f"Global step: {state['global_step']}")
        if 'loss' in state:
            print(f"Loss: {state['loss']:.6f}")

        # Auxiliary state presence flags
        if 'optimizer_state_dict' in state:
            print("Optimizer state: Present")
        if 'scheduler_state_dict' in state:
            print("Scheduler state: Present")

        print(divider)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
def main():
    """CLI entry point: recover a checkpoint directory, or inspect one file.

    Exits with status 0 on successful recovery, 1 on failure; inspect mode
    returns normally.
    """
    parser = argparse.ArgumentParser(description='Recover corrupted checkpoints')
    # --checkpoint_dir / --output are only meaningful in recovery mode, so
    # they must not be required=True: that made --inspect unusable without
    # supplying two irrelevant arguments. They are validated manually below.
    parser.add_argument('--checkpoint_dir', type=str, default=None,
                        help='Directory containing checkpoint files')
    parser.add_argument('--output', type=str, default=None,
                        help='Path to save recovered checkpoint')
    parser.add_argument('--inspect', type=str, default=None,
                        help='Inspect a specific checkpoint file')
    parser.add_argument('--verbose', action='store_true',
                        help='Print detailed information')
    args = parser.parse_args()

    if args.inspect:
        inspect_checkpoint(args.inspect)
    else:
        # Recovery mode still requires both arguments (same behavior as
        # before for all previously-valid invocations).
        if not args.checkpoint_dir or not args.output:
            parser.error('--checkpoint_dir and --output are required unless --inspect is used')
        success = recover_checkpoint(args.checkpoint_dir, args.output, args.verbose)
        exit(0 if success else 1)

if __name__ == "__main__":
    main()