import torch
import os
import sys


def _tensor_nbytes(t):
    """Return the in-memory size of one tensor in bytes."""
    return t.numel() * t.element_size()


def _state_dict_nbytes(ckpt):
    """Total bytes of all tensors under ckpt["state_dict"] (0 if the key is absent)."""
    return sum(_tensor_nbytes(v) for v in ckpt.get("state_dict", {}).values())


def _optimizer_nbytes(ckpt):
    """Total bytes of tensor-valued entries in ckpt["optimizer_states"] (0 if absent)."""
    total = 0
    for opt in ckpt.get("optimizer_states", []):
        # Optimizer state structure can vary; this is a general traversal
        # over the common {"state": {param_id: {name: tensor}}} layout.
        if isinstance(opt, dict) and "state" in opt:
            for state in opt["state"].values():
                for v in state.values():
                    if torch.is_tensor(v):
                        total += _tensor_nbytes(v)
    return total


def analyze_ckpt(path):
    """Print a size breakdown of a checkpoint file.

    Reports the total file size, the bytes held by model weights
    (``state_dict``) and the bytes held by optimizer states. Missing
    files and load failures are reported and skipped, not raised.
    """
    if not os.path.exists(path):
        print(f"File not found: {path}")
        return
    size_mb = os.path.getsize(path) / (1024 * 1024)
    print(f"\nAnalyzing {path} (Total File Size: {size_mb:.2f} MB)...")
    try:
        # SECURITY NOTE(review): torch.load unpickles arbitrary objects —
        # only run this on trusted checkpoints. weights_only=True is not
        # used here because it would reject the optimizer-state payload
        # this script exists to measure.
        ckpt = torch.load(path, map_location="cpu")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return

    # 1. Measure Model Weights (state_dict)
    state_dict_size = _state_dict_nbytes(ckpt)
    print(f" - Model Weights (state_dict): {state_dict_size / (1024*1024):.2f} MB")

    # 2. Measure Optimizer States
    opt_size = _optimizer_nbytes(ckpt)
    print(f" - Optimizer States: {opt_size / (1024*1024):.2f} MB")


if __name__ == "__main__":
    # Replace with your actual paths if different
    analyze_ckpt("s050000.ckpt")
    analyze_ckpt("./checkpoints/last_manual.ckpt")