| import torch |
| import os |
| import sys |
|
|
def _tensor_bytes(values):
    """Return the total byte size of every tensor in *values*, skipping non-tensors."""
    return sum(v.numel() * v.element_size() for v in values if torch.is_tensor(v))


def analyze_ckpt(path, *, weights_only=None):
    """Print a size breakdown of a checkpoint file and return the byte counts.

    The checkpoint is loaded onto the CPU. Two sections are measured if present:
    the ``"state_dict"`` mapping (model weights) and the ``"optimizer_states"``
    sequence, whose dict entries are read via their ``"state"`` key.

    Sizes are computed from tensor element counts. Tensors that share storage
    are therefore counted once per reference, and serialization overhead is not
    included, so the section totals need not add up to the file size.
    Non-tensor values (ints, None, arbitrary objects) contribute 0 bytes.

    Args:
        path: Path to the checkpoint file.
        weights_only: Forwarded to ``torch.load`` when not ``None``. Leave it
            as ``None`` to use the installed PyTorch's default. ``False`` allows
            checkpoints that pickle arbitrary Python objects to load under
            newer PyTorch versions, but it executes arbitrary pickle code, so
            only use it on files you trust.

    Returns:
        ``None`` if the file is missing or cannot be loaded. Otherwise a dict
        of integer byte counts with keys ``"file_bytes"``,
        ``"state_dict_bytes"`` and ``"optimizer_bytes"``.
    """
    if not os.path.exists(path):
        print(f"File not found: {path}")
        return None

    file_bytes = os.path.getsize(path)
    size_mb = file_bytes / (1024 * 1024)
    print(f"\nAnalyzing {path} (Total File Size: {size_mb:.2f} MB)...")

    load_kwargs = {"map_location": "cpu"}
    if weights_only is not None:
        # Only pass the flag when asked, so the PyTorch default is kept otherwise.
        load_kwargs["weights_only"] = weights_only
    try:
        ckpt = torch.load(path, **load_kwargs)
    except Exception as e:  # CLI boundary: report any load failure and skip this file
        print(f"Error loading checkpoint: {e}")
        return None

    result = {"file_bytes": file_bytes, "state_dict_bytes": 0, "optimizer_bytes": 0}

    if not isinstance(ckpt, dict):
        # torch.load returns whatever object was saved; the "in" checks below
        # only make sense for a dict-shaped checkpoint.
        print(f" - Checkpoint is a {type(ckpt).__name__}, not a dict; no sections to measure")
        return result

    if "state_dict" in ckpt:
        # Values may include non-tensor entries, which _tensor_bytes skips.
        state_dict_size = _tensor_bytes(ckpt["state_dict"].values())
        result["state_dict_bytes"] = state_dict_size
        print(f" - Model Weights (state_dict): {state_dict_size / (1024*1024):.2f} MB")

    if "optimizer_states" in ckpt:
        opt_size = 0
        for opt in ckpt["optimizer_states"]:
            # Entries are presumably optimizer.state_dict() dicts. Per-parameter
            # state is read from "state" as {param_id: {name: value}}.
            if isinstance(opt, dict) and "state" in opt:
                for param_state in opt["state"].values():
                    opt_size += _tensor_bytes(param_state.values())
        result["optimizer_bytes"] = opt_size
        print(f" - Optimizer States: {opt_size / (1024*1024):.2f} MB")

    return result
|
|
if __name__ == "__main__":
    # Analyze the checkpoints named on the command line. With no arguments,
    # fall back to the two paths this script originally hard-coded.
    targets = sys.argv[1:] or ["s050000.ckpt", "./checkpoints/last_manual.ckpt"]
    for target in targets:
        analyze_ckpt(target)
|
|