"""Recover from corrupted or incomplete checkpoints."""

import argparse
import logging
import os
import sys
from pathlib import Path

import torch

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
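
# A recoverable checkpoint is expected to be a plain dict saved with
# torch.save. The key names come from the checks in recover_checkpoint()
# below; 'model_state_dict' is required, the rest are optional. A sketch
# of the assumed layout:
#
#   torch.save({
#       'model_state_dict': model.state_dict(),
#       'optimizer_state_dict': optimizer.state_dict(),
#       'scheduler_state_dict': scheduler.state_dict(),
#       'epoch': epoch,
#       'global_step': global_step,
#       'loss': loss,
#   }, 'checkpoint.pt')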

def recover_checkpoint(checkpoint_dir: str, output_path: str, verbose: bool = True) -> bool:
    """
    Try to recover a usable checkpoint from a directory of candidates.

    Args:
        checkpoint_dir: Directory containing checkpoint files
        output_path: Path to save the recovered checkpoint
        verbose: Print detailed information

    Returns:
        True if recovery succeeded, False otherwise
    """
    checkpoint_dir = Path(checkpoint_dir)

    if not checkpoint_dir.exists():
        logger.error(f"Checkpoint directory not found: {checkpoint_dir}")
        return False

    checkpoint_files = list(checkpoint_dir.glob("*.pt")) + list(checkpoint_dir.glob("*.pth"))

    if not checkpoint_files:
        logger.error(f"No checkpoint files found in {checkpoint_dir}")
        return False

    logger.info(f"Found {len(checkpoint_files)} checkpoint files")

    # Try the most recently modified checkpoint first.
    checkpoint_files = sorted(checkpoint_files, key=os.path.getmtime, reverse=True)

    for ckpt_file in checkpoint_files:
        try:
            if verbose:
                logger.info(f"Attempting to load: {ckpt_file.name}")

            # weights_only=False allows non-tensor objects (e.g. scheduler
            # state) to unpickle; only use this on checkpoints you trust.
            checkpoint = torch.load(ckpt_file, map_location='cpu', weights_only=False)

            if not isinstance(checkpoint, dict):
                logger.warning(f"  ✗ Not a dictionary: {type(checkpoint)}")
                continue

            required_keys = ['model_state_dict']
            optional_keys = ['optimizer_state_dict', 'scheduler_state_dict',
                             'epoch', 'global_step', 'loss']

            if not all(k in checkpoint for k in required_keys):
                missing = [k for k in required_keys if k not in checkpoint]
                logger.warning(f"  ✗ Missing required keys: {missing}")
                logger.info(f"    Available keys: {list(checkpoint.keys())}")
                continue

            logger.info(f"  ✓ Valid checkpoint found: {ckpt_file.name}")

            if verbose:
                logger.info("  Checkpoint information:")
                for key in optional_keys:
                    if key in checkpoint:
                        value = checkpoint[key]
                        if key in ['epoch', 'global_step']:
                            logger.info(f"    {key}: {value}")
                        elif key == 'loss':
                            logger.info(f"    {key}: {value:.6f}")

                model_state = checkpoint['model_state_dict']
                num_params = sum(v.numel() for v in model_state.values())
                logger.info(f"    Parameters: {num_params:,}")

            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            torch.save(checkpoint, output_path)
            logger.info(f"  ✓ Saved recovered checkpoint to: {output_path}")

            return True

        except Exception as e:
            logger.warning(f"  ✗ Failed to load {ckpt_file.name}: {e}")
            continue

    logger.error("✗ No valid checkpoint could be recovered")
    return False
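

# Hypothetical programmatic use (the module name and paths here are
# illustrative assumptions, not part of the script itself):
#
#   from recover import recover_checkpoint
#   if recover_checkpoint("runs/exp1/checkpoints", "runs/exp1/recovered.pt"):
#       print("recovery succeeded")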


def inspect_checkpoint(checkpoint_path: str) -> None:
    """Inspect a checkpoint file and print a summary of its contents."""
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        print("=" * 60)
        print(f"Checkpoint: {checkpoint_path}")
        print("=" * 60)

        if not isinstance(checkpoint, dict):
            print(f"Type: {type(checkpoint)}")
            print("Not a dictionary - unexpected format")
            return

        print(f"Keys: {list(checkpoint.keys())}")
        print()

        if 'model_state_dict' in checkpoint:
            model_state = checkpoint['model_state_dict']
            num_params = sum(v.numel() for v in model_state.values())
            print(f"Model parameters: {num_params:,}")
            print(f"Model state keys: {len(model_state)}")

        if 'epoch' in checkpoint:
            print(f"Epoch: {checkpoint['epoch']}")

        if 'global_step' in checkpoint:
            print(f"Global step: {checkpoint['global_step']}")

        if 'loss' in checkpoint:
            print(f"Loss: {checkpoint['loss']:.6f}")

        if 'optimizer_state_dict' in checkpoint:
            print("Optimizer state: Present")

        if 'scheduler_state_dict' in checkpoint:
            print("Scheduler state: Present")

        print("=" * 60)

    except Exception as e:
        print(f"Error loading checkpoint: {e}")


def main():
    parser = argparse.ArgumentParser(description='Recover corrupted checkpoints')
    parser.add_argument('--checkpoint_dir', type=str, default=None,
                        help='Directory containing checkpoint files')
    parser.add_argument('--output', type=str, default=None,
                        help='Path to save recovered checkpoint')
    parser.add_argument('--inspect', type=str, default=None,
                        help='Inspect a specific checkpoint file instead of recovering')
    parser.add_argument('--verbose', action='store_true',
                        help='Print detailed information')

    args = parser.parse_args()

    if args.inspect:
        inspect_checkpoint(args.inspect)
    elif args.checkpoint_dir and args.output:
        success = recover_checkpoint(args.checkpoint_dir, args.output, args.verbose)
        sys.exit(0 if success else 1)
    else:
        # Recovery mode needs both a source directory and an output path.
        parser.error('--checkpoint_dir and --output are required unless --inspect is given')


if __name__ == "__main__":
    main()
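
# Example invocations (the script filename is an assumption; the flags match
# the argparse definitions above):
#
#   python recover_checkpoint.py --checkpoint_dir checkpoints/ --output recovered/model.pt --verbose
#   python recover_checkpoint.py --inspect checkpoints/latest.pt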