File size: 1,463 Bytes
fbb20ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import os
import sys

def _tensor_nbytes(value):
    """Return the in-memory byte size of *value* if it is a tensor, else 0."""
    if torch.is_tensor(value):
        return value.numel() * value.element_size()
    return 0


def analyze_ckpt(path):
    """Print a size breakdown of a checkpoint file and return the byte counts.

    Loads the checkpoint at *path* onto the CPU and sums the in-memory size
    (``numel() * element_size()``) of every tensor stored under the
    ``"state_dict"`` key and inside each optimizer's ``"state"`` mapping in
    the ``"optimizer_states"`` list. Non-tensor entries (counters, metadata)
    are skipped instead of aborting the traversal.

    Args:
        path: Filesystem path of the checkpoint file.

    Returns:
        ``{"state_dict": int, "optimizer_states": int}`` with byte totals, or
        ``None`` if the file does not exist or cannot be loaded (a message is
        printed in either case).
    """
    mb = 1024 * 1024

    if not os.path.exists(path):
        print(f"File not found: {path}")
        return None

    size_mb = os.path.getsize(path) / mb
    print(f"\nAnalyzing {path} (Total File Size: {size_mb:.2f} MB)...")

    # NOTE(review): torch.load unpickles the file, so only run this on
    # checkpoints you trust. On recent torch releases the default
    # weights_only=True may reject checkpoints that pickle arbitrary objects
    # (e.g. hyper-parameters) — confirm against the torch version in use.
    try:
        ckpt = torch.load(path, map_location="cpu")
    except Exception as e:  # top-level script boundary: report and move on
        print(f"Error loading checkpoint: {e}")
        return None

    if not isinstance(ckpt, dict):
        # e.g. a bare tensor or module was saved; there are no sections to measure.
        print(f"  - Unexpected checkpoint type {type(ckpt).__name__}; expected a dict")
        ckpt = {}

    # 1. Measure Model Weights (state_dict)
    state_dict = ckpt.get("state_dict") or {}
    state_dict_size = sum(_tensor_nbytes(v) for v in state_dict.values())
    print(f"  - Model Weights (state_dict):   {state_dict_size / mb:.2f} MB")

    # 2. Measure Optimizer States
    opt_size = 0
    for opt in ckpt.get("optimizer_states") or []:
        # Optimizer state structure can vary; only torch-style
        # {"state": {param_id: {name: value, ...}}} entries are measured.
        if isinstance(opt, dict) and "state" in opt:
            for state in opt["state"].values():
                if isinstance(state, dict):
                    opt_size += sum(_tensor_nbytes(v) for v in state.values())
    print(f"  - Optimizer States:             {opt_size / mb:.2f} MB")

    return {"state_dict": state_dict_size, "optimizer_states": opt_size}

if __name__ == "__main__":
    # Checkpoint paths may be passed on the command line; with no arguments
    # the script falls back to the original default paths.
    default_paths = ["s050000.ckpt", "./checkpoints/last_manual.ckpt"]
    for ckpt_path in sys.argv[1:] or default_paths:
        analyze_ckpt(ckpt_path)