"""Recover from corrupted or incomplete checkpoints."""
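# Example invocations (the script filename is an assumption; adjust to match
# this file's actual name):
#   python recover_checkpoint.py --checkpoint_dir checkpoints/ --output recovered.pt --verbose
#   python recover_checkpoint.py --inspect checkpoints/latest.pt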
import argparse
import logging
import os
import sys
from pathlib import Path

import torch

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def recover_checkpoint(checkpoint_dir: str, output_path: str, verbose: bool = True):
    """
    Try to recover a corrupted checkpoint.

    Args:
        checkpoint_dir: Directory containing checkpoint files
        output_path: Path to save the recovered checkpoint
        verbose: Print detailed information

    Returns:
        True if recovery succeeded, False otherwise
    """
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.exists():
        logger.error(f"Checkpoint directory not found: {checkpoint_dir}")
        return False

    # Find all checkpoint files
    checkpoint_files = list(checkpoint_dir.glob("*.pt")) + list(checkpoint_dir.glob("*.pth"))
    if not checkpoint_files:
        logger.error(f"No checkpoint files found in {checkpoint_dir}")
        return False
    logger.info(f"Found {len(checkpoint_files)} checkpoint files")

    # Try to load checkpoints in reverse order (newest first)
    checkpoint_files = sorted(checkpoint_files, key=os.path.getmtime, reverse=True)

    for ckpt_file in checkpoint_files:
        try:
            if verbose:
                logger.info(f"Attempting to load: {ckpt_file.name}")

            # Deserialize on CPU so recovery works without a GPU; a truncated
            # or corrupted file raises here and we move on to the next one.
            # Note: PyTorch >= 2.6 defaults to weights_only=True, which can
            # reject checkpoints containing arbitrary pickled objects; pass
            # weights_only=False only for checkpoints you trust.
            checkpoint = torch.load(ckpt_file, map_location='cpu')

            # Validate checkpoint structure
            if not isinstance(checkpoint, dict):
                logger.warning(f"  ✗ Not a dictionary: {type(checkpoint)}")
                continue

            required_keys = ['model_state_dict']
            optional_keys = ['optimizer_state_dict', 'scheduler_state_dict', 'epoch', 'global_step', 'loss']
            if not all(k in checkpoint for k in required_keys):
                missing = [k for k in required_keys if k not in checkpoint]
                logger.warning(f"  ✗ Missing required keys: {missing}")
                logger.info(f"    Available keys: {list(checkpoint.keys())}")
                continue

            # Checkpoint is valid
            logger.info(f"  ✓ Valid checkpoint found: {ckpt_file.name}")

            # Print checkpoint info
            if verbose:
                logger.info("  Checkpoint information:")
                for key in optional_keys:
                    if key in checkpoint:
                        value = checkpoint[key]
                        if key in ['epoch', 'global_step']:
                            logger.info(f"    {key}: {value}")
                        elif key == 'loss':
                            logger.info(f"    {key}: {value:.6f}")

                # Count model parameters
                model_state = checkpoint['model_state_dict']
                num_params = sum(v.numel() for v in model_state.values())
                logger.info(f"    Parameters: {num_params:,}")

            # Save recovered checkpoint
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(checkpoint, output_path)
            logger.info(f"  ✓ Saved recovered checkpoint to: {output_path}")
            return True

        except Exception as e:
            logger.warning(f"  ✗ Failed to load {ckpt_file.name}: {e}")
            continue

    logger.error("✗ No valid checkpoint could be recovered")
    return False
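# Minimal sketch of resuming from a recovered checkpoint. The `model` and
# `optimizer` objects are assumptions (not defined in this script); only
# 'model_state_dict' is guaranteed by recover_checkpoint, so the optional
# entries are guarded:
#
#   checkpoint = torch.load("recovered.pt", map_location='cpu')
#   model.load_state_dict(checkpoint['model_state_dict'])
#   if 'optimizer_state_dict' in checkpoint:
#       optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#   start_epoch = checkpoint.get('epoch', 0)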
def inspect_checkpoint(checkpoint_path: str):
    """Inspect a checkpoint file and print a summary of its contents."""
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        print("=" * 60)
        print(f"Checkpoint: {checkpoint_path}")
        print("=" * 60)

        if not isinstance(checkpoint, dict):
            print(f"Type: {type(checkpoint)}")
            print("Not a dictionary - unexpected format")
            return

        print(f"Keys: {list(checkpoint.keys())}")
        print()

        # Model state
        if 'model_state_dict' in checkpoint:
            model_state = checkpoint['model_state_dict']
            num_params = sum(v.numel() for v in model_state.values())
            print(f"Model parameters: {num_params:,}")
            print(f"Model state keys: {len(model_state)}")

        # Training info
        if 'epoch' in checkpoint:
            print(f"Epoch: {checkpoint['epoch']}")
        if 'global_step' in checkpoint:
            print(f"Global step: {checkpoint['global_step']}")
        if 'loss' in checkpoint:
            print(f"Loss: {checkpoint['loss']:.6f}")

        # Optimizer
        if 'optimizer_state_dict' in checkpoint:
            print("Optimizer state: Present")

        # Scheduler
        if 'scheduler_state_dict' in checkpoint:
            print("Scheduler state: Present")

        print("=" * 60)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
def main():
    parser = argparse.ArgumentParser(description='Recover corrupted checkpoints')
    parser.add_argument('--checkpoint_dir', type=str,
                        help='Directory containing checkpoint files')
    parser.add_argument('--output', type=str,
                        help='Path to save recovered checkpoint')
    parser.add_argument('--inspect', type=str, default=None,
                        help='Inspect a specific checkpoint file instead of recovering')
    parser.add_argument('--verbose', action='store_true',
                        help='Print detailed information')
    args = parser.parse_args()

    if args.inspect:
        inspect_checkpoint(args.inspect)
    else:
        # --checkpoint_dir and --output are only needed in recovery mode,
        # so validate them here instead of marking them required in argparse
        # (which would also demand them for --inspect).
        if not args.checkpoint_dir or not args.output:
            parser.error('--checkpoint_dir and --output are required unless --inspect is given')
        success = recover_checkpoint(args.checkpoint_dir, args.output, args.verbose)
        sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()