File size: 4,200 Bytes
672259a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""

Check model parameters from a checkpoint



Usage:

  python check.py <config_file> <checkpoint_file>



Example:

  python check.py config/train_reflow_web.py out-web/ckpt.pt

"""
import sys
import os
import torch

# -----------------------------------------------------------------------------
# Configuration loading
# -----------------------------------------------------------------------------
# Expect exactly two positional arguments: the config file and the checkpoint.
if len(sys.argv) != 3:
    print("ERROR: Invalid arguments!")
    print("Usage: python check.py <config_file> <checkpoint_file>")
    print("Example: python check.py config/train_reflow_web.py out-web/ckpt.pt")
    sys.exit(1)

config_file = sys.argv[1]
checkpoint_file = sys.argv[2]

# isfile (rather than exists) also rejects directories, so a bad path is
# reported here with a clear message instead of failing later inside
# open() / torch.load() with a traceback.
if not os.path.isfile(config_file):
    print(f"ERROR: Config file not found: {config_file}")
    sys.exit(1)

if not os.path.isfile(checkpoint_file):
    print(f"ERROR: Checkpoint file not found: {checkpoint_file}")
    sys.exit(1)

# Load config
# The config is executed in this module's global namespace, so every name it
# assigns (e.g. ``model_config``) becomes a module-level global here.
# NOTE(review): exec runs arbitrary code; only use trusted config/model files.
print(f"Loading config from: {config_file}")
with open(config_file, 'rb') as config_fh:
    config_source = config_fh.read()
# compile() with the real path makes tracebacks point at the config file;
# passing bytes lets Python apply the normal source-encoding rules (PEP 263).
exec(compile(config_source, config_file, 'exec'))

# Load model configuration
model_config = globals().get('model_config')
if not model_config:
    print("ERROR: 'model_config' is required in config file")
    sys.exit(1)

# The model definition is likewise executed into this namespace.
model_file = f"models/{model_config}.py"
try:
    with open(model_file, 'rb') as model_fh:
        model_source = model_fh.read()
except FileNotFoundError:
    print(f"ERROR: Model file not found: {model_file}")
    sys.exit(1)
# Executed outside the try so a FileNotFoundError raised by the model code
# itself is not misreported as a missing model file.
exec(compile(model_source, model_file, 'exec'))

# Import GPTConfig and GPT (defined by the model file executed above)
GPTConfig = globals().get('GPTConfig')
GPT = globals().get('GPT')
if GPTConfig is None or GPT is None:
    print(f"ERROR: {model_file} must define both 'GPTConfig' and 'GPT'")
    sys.exit(1)

# Load checkpoint
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only inspect checkpoints from trusted sources.
print(f"Loading checkpoint from: {checkpoint_file}")
checkpoint = torch.load(checkpoint_file, map_location='cpu')  # load all tensors onto the CPU
model_args = checkpoint['model_args']

# Rebuild the model from the constructor arguments saved in the checkpoint.
model = GPT(GPTConfig(**model_args))
state_dict = checkpoint['model']

# Checkpoints saved from a torch.compile'd model (PyTorch 2.0+) carry an
# '_orig_mod.' key prefix; rename those entries so they match the plain model.
compiled_prefix = '_orig_mod.'
compiled_prefix_len = len(compiled_prefix)
prefixed_ckpt_keys = [ckpt_key for ckpt_key in state_dict
                      if ckpt_key.startswith(compiled_prefix)]
for ckpt_key in prefixed_ckpt_keys:
    state_dict[ckpt_key[compiled_prefix_len:]] = state_dict.pop(ckpt_key)

# Default (strict) load: mismatched keys raise rather than being ignored.
model.load_state_dict(state_dict)

# Print model information
section_rule = "=" * 60
print("\n" + section_rule)
print("MODEL INFORMATION")
print(section_rule)

print(f"\nModel Architecture: {model_config}")
print(f"Checkpoint: {checkpoint_file}")

# One line per saved constructor argument, names left-aligned in a
# 20-character column.
print("\nModel Arguments:")
for arg_name, arg_value in model_args.items():
    print(f"  {arg_name:20s} = {arg_value}")

# NOTE(review): this prints whatever model.get_num_params() returns with its
# default arguments; in nanoGPT-style models that default excludes position
# embeddings, so "Total" may be an undercount. Confirm for this architecture.
print(f"\nTotal Parameters: {model.get_num_params()/1e6:.2f}M")

# Count parameters by component
def _print_component_row(label, count):
    """Print one aligned row of the per-component parameter table.

    Args:
        label: Component name; a trailing colon is appended.
        count: Number of elements, printed in millions with two decimals.

    The label is padded to a fixed 40-character column (after a two-space
    indent) so the counts of every row line up vertically.
    """
    print(f"  {label + ':':<40}{count / 1e6:>10.2f}M")


if hasattr(model, 'transformer'):
    print("\nParameters by component:")

    # Show wte (embedding) - for Reflow includes vocab_to_signals + signal_basis
    if hasattr(model.transformer, 'wte'):
        wte = model.transformer.wte
        if hasattr(wte, 'vocab_to_signals'):
            _print_component_row("transformer.wte.vocab_to_signals",
                                 wte.vocab_to_signals.weight.numel())
        if hasattr(wte, 'signal_basis'):
            _print_component_row("transformer.wte.signal_basis",
                                 wte.signal_basis.numel())
        # Sum over every parameter registered on the embedding module.
        _print_component_row("transformer.wte (total)",
                             sum(p.numel() for p in wte.parameters()))

    # Count transformer.h (layers)
    if hasattr(model.transformer, 'h'):
        _print_component_row("transformer.h (all layers)",
                             sum(p.numel() for p in model.transformer.h.parameters()))

    # Show ln_f
    if hasattr(model.transformer, 'ln_f'):
        _print_component_row("transformer.ln_f",
                             sum(p.numel() for p in model.transformer.ln_f.parameters()))

# requires_grad splits the model's parameters into trainable vs. frozen.
print(f"\nTotal trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.2f}M")
print(f"Total non-trainable parameters: {sum(p.numel() for p in model.parameters() if not p.requires_grad)/1e6:.2f}M")

# Training info if available
if 'iter_num' in checkpoint:
    print("\nTraining Info:")
    print(f"  iter_num: {checkpoint['iter_num']}")
    saved_best_val_loss = checkpoint.get('best_val_loss', 'N/A')
    # A loss saved straight from training may be a one-element tensor, which
    # would print as "tensor(...)"; show the plain Python number instead.
    if isinstance(saved_best_val_loss, torch.Tensor) and saved_best_val_loss.numel() == 1:
        saved_best_val_loss = saved_best_val_loss.item()
    print(f"  best_val_loss: {saved_best_val_loss}")

print("=" * 60)