|
|
import torch |
|
|
import os |
|
|
import sys |
|
|
|
|
|
def _tensor_bytes(value):
    """Return the payload size in bytes of *value* if it is a tensor, else 0."""
    if torch.is_tensor(value):
        return value.numel() * value.element_size()
    return 0


def analyze_ckpt(path):
    """Print a size breakdown of the checkpoint stored at *path*.

    The checkpoint is loaded onto the CPU. Three sizes are reported:

    * the total file size on disk,
    * the bytes held by tensors under the ``"state_dict"`` key,
    * the bytes held by tensors in each optimizer's ``"state"`` mapping
      under the ``"optimizer_states"`` key.

    Non-tensor entries are ignored. Tensors that share storage are counted
    once per reference, so the figures are in-memory payload sizes. They
    are not necessarily equal to the on-disk size.

    Args:
        path: Filesystem path of the checkpoint to inspect.

    Returns:
        A dict ``{"file": int, "state_dict": int, "optimizer_states": int}``
        with sizes in bytes. Returns ``None`` if the file does not exist or
        cannot be loaded; in that case only an error message is printed.
    """
    mib = 1024 * 1024

    if not os.path.exists(path):
        print(f"File not found: {path}")
        return None

    file_bytes = os.path.getsize(path)
    print(f"\nAnalyzing {path} (Total File Size: {file_bytes / mib:.2f} MB)...")

    try:
        # NOTE(review): since PyTorch 2.6, torch.load defaults to
        # weights_only=True. That rejects checkpoints which pickle arbitrary
        # Python objects (e.g. some hyper-parameter containers). Passing
        # weights_only=False lifts that restriction but allows arbitrary code
        # execution, so only do it for checkpoints you trust.
        ckpt = torch.load(path, map_location="cpu")
    except Exception as e:  # CLI boundary: report and keep going.
        print(f"Error loading checkpoint: {e}")
        return None

    # A checkpoint need not be a dict (e.g. a bare tensor or a pickled module).
    # A membership test on such objects raises instead of returning False.
    if not isinstance(ckpt, dict):
        print(f" - Checkpoint is a {type(ckpt).__name__}, not a dict; no breakdown available")
        ckpt = {}

    state_dict = ckpt.get("state_dict") or {}
    state_dict_size = sum(_tensor_bytes(v) for v in state_dict.values())
    print(f" - Model Weights (state_dict): {state_dict_size / mib:.2f} MB")

    opt_size = 0
    for opt in ckpt.get("optimizer_states") or []:
        # Each entry is expected to be an optimizer state_dict whose "state"
        # maps parameter ids to per-parameter dicts. Anything else is skipped.
        if not (isinstance(opt, dict) and "state" in opt):
            continue
        for per_param in opt["state"].values():
            if isinstance(per_param, dict):
                opt_size += sum(_tensor_bytes(v) for v in per_param.values())
    print(f" - Optimizer States: {opt_size / mib:.2f} MB")

    return {
        "file": file_bytes,
        "state_dict": state_dict_size,
        "optimizer_states": opt_size,
    }
|
|
|
|
|
if __name__ == "__main__":
    # Inspect every checkpoint path given on the command line. With no
    # arguments, fall back to the two default locations this script targets.
    ckpt_paths = sys.argv[1:] or ["s050000.ckpt", "./checkpoints/last_manual.ckpt"]
    for ckpt_path in ckpt_paths:
        analyze_ckpt(ckpt_path)
|
|
|