"""Print a short summary of a PyTorch checkpoint file given on the command line."""
import sys

import torch


def inspect_checkpoint(path):
    """Load the checkpoint at *path* and print a summary of its contents.

    The summary lists the top-level keys, whether an ``"optimizer_states"``
    entry is present, and how many entries ``"state_dict"`` holds (if any).

    Args:
        path: Filesystem path of a file written by ``torch.save``.

    Returns:
        True if the file was loaded, False if loading failed. The error is
        printed rather than raised, so callers that ignore the return value
        keep the original fire-and-forget behaviour.
    """
    print(f"Inspecting {path}")
    try:
        # map_location="cpu" lets checkpoints saved from a GPU be inspected on
        # a CPU-only machine; without it torch.load raises on CUDA tensors.
        # SECURITY: weights_only=False unpickles arbitrary Python objects and can
        # execute code embedded in the file -- only inspect trusted checkpoints.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)
    except Exception as e:  # CLI boundary: report any load failure and carry on
        print(f"Error loading checkpoint: {e}")
        return False

    # torch.save can store any object (e.g. a bare tensor or a whole module);
    # only a dict has the keys this report looks at.
    if not isinstance(ckpt, dict):
        print(
            f"Checkpoint is a {type(ckpt).__name__}, not a dict; "
            "no keys to inspect."
        )
        return True

    print("Keys in checkpoint:", ckpt.keys())
    if "optimizer_states" in ckpt:
        print("Optimizer states found: YES")
    else:
        print("Optimizer states found: NO")
    if "state_dict" in ckpt:
        print(f"Model keys: {len(ckpt['state_dict'])} items")
    return True


def main(argv=None):
    """Run the CLI and return a process exit status.

    Args:
        argv: Arguments after the program name; defaults to ``sys.argv[1:]``.

    Returns:
        0 on success, 1 if the checkpoint could not be loaded, 2 on bad usage.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        print("Usage: python inspect_ckpt.py <checkpoint_path>")
        return 2
    return 0 if inspect_checkpoint(args[0]) else 1


if __name__ == "__main__":
    sys.exit(main())