# reFlow / check.py
# NOTE(review): the three lines below were Hugging Face web-page residue
# ("reuAC's picture" / "Upload folder using huggingface_hub" / commit
# 672259a verified) accidentally pasted above the docstring; they were not
# valid Python and broke the file, so they are preserved here as comments.
"""
Check model parameters from a checkpoint
Usage:
python check.py <config_file> <checkpoint_file>
Example:
python check.py config/train_reflow_web.py out-web/ckpt.pt
"""
import sys
import os
import torch

# -----------------------------------------------------------------------------
# Argument handling
# -----------------------------------------------------------------------------
if len(sys.argv) != 3:
    print("ERROR: Invalid arguments!")
    print("Usage: python check.py <config_file> <checkpoint_file>")
    print("Example: python check.py config/train_reflow_web.py out-web/ckpt.pt")
    sys.exit(1)

config_file = sys.argv[1]
checkpoint_file = sys.argv[2]

if not os.path.exists(config_file):
    print(f"ERROR: Config file not found: {config_file}")
    sys.exit(1)
if not os.path.exists(checkpoint_file):
    print(f"ERROR: Checkpoint file not found: {checkpoint_file}")
    sys.exit(1)

# -----------------------------------------------------------------------------
# Configuration loading
# -----------------------------------------------------------------------------
# The config file is executed into this module's globals (nanoGPT-style
# configuration). NOTE: exec of a file is arbitrary code execution -- only
# point this at trusted, project-local configs.
print(f"Loading config from: {config_file}")
with open(config_file, encoding="utf-8") as f:  # 'with' closes the handle; bare open().read() leaked it
    # compile() attaches the real filename so tracebacks point into the config
    exec(compile(f.read(), config_file, "exec"))

# The config must define 'model_config', which names a module under models/.
model_config = globals().get('model_config')
if not model_config:
    print("ERROR: 'model_config' is required in config file")
    sys.exit(1)

# Execute the model definition file; it is expected to define GPTConfig and GPT.
model_file = f"models/{model_config}.py"
try:
    with open(model_file, encoding="utf-8") as f:
        exec(compile(f.read(), model_file, "exec"))
except FileNotFoundError:
    print(f"ERROR: Model file not found: {model_file}")
    sys.exit(1)

# Pull the classes the exec'd model file defined into explicit names.
GPTConfig = globals()['GPTConfig']
GPT = globals()['GPT']

# -----------------------------------------------------------------------------
# Checkpoint loading
# -----------------------------------------------------------------------------
# map_location='cpu' so no GPU is needed just to inspect parameters.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# can reject checkpoints holding non-tensor objects (e.g. the model_args dict
# with custom types); if loading fails on a trusted checkpoint, pass
# weights_only=False explicitly.
print(f"Loading checkpoint from: {checkpoint_file}")
checkpoint = torch.load(checkpoint_file, map_location='cpu')
model_args = checkpoint['model_args']

# Rebuild the model from the saved constructor args, then load its weights.
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
state_dict = checkpoint['model']

# torch.compile() wraps the model and prefixes every state-dict key with
# '_orig_mod.'; strip the prefix so keys match an uncompiled model.
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

# -----------------------------------------------------------------------------
# Report
# -----------------------------------------------------------------------------
print("\n" + "=" * 60)
print("MODEL INFORMATION")
print("=" * 60)
print(f"\nModel Architecture: {model_config}")
print(f"Checkpoint: {checkpoint_file}")
print(f"\nModel Arguments:")
for k, v in model_args.items():
    print(f" {k:20s} = {v}")
print(f"\nTotal Parameters: {model.get_num_params()/1e6:.2f}M")

# Per-component parameter counts. Every access is hasattr-guarded because
# different model variants expose different submodules.
if hasattr(model, 'transformer'):
    print("\nParameters by component:")
    # wte (embedding) - for Reflow includes vocab_to_signals + signal_basis
    if hasattr(model.transformer, 'wte'):
        wte = model.transformer.wte
        if hasattr(wte, 'vocab_to_signals'):
            vocab_to_signals_params = wte.vocab_to_signals.weight.numel()
            print(f" transformer.wte.vocab_to_signals: {vocab_to_signals_params/1e6:>10.2f}M")
        if hasattr(wte, 'signal_basis'):
            signal_basis_params = wte.signal_basis.numel()
            print(f" transformer.wte.signal_basis: {signal_basis_params/1e6:>10.2f}M")
        wte_params = sum(p.numel() for p in wte.parameters())
        print(f" transformer.wte (total): {wte_params/1e6:>10.2f}M")
    # All transformer layers together.
    if hasattr(model.transformer, 'h'):
        h_params = sum(p.numel() for p in model.transformer.h.parameters())
        print(f" transformer.h (all layers): {h_params/1e6:>10.2f}M")
    # Final layer norm.
    if hasattr(model.transformer, 'ln_f'):
        ln_f_params = sum(p.numel() for p in model.transformer.ln_f.parameters())
        print(f" transformer.ln_f: {ln_f_params/1e6:>10.2f}M")

print(f"\nTotal trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.2f}M")
print(f"Total non-trainable parameters: {sum(p.numel() for p in model.parameters() if not p.requires_grad)/1e6:.2f}M")

# Training progress info, when the checkpoint recorded it.
if 'iter_num' in checkpoint:
    print(f"\nTraining Info:")
    print(f" iter_num: {checkpoint['iter_num']}")
    print(f" best_val_loss: {checkpoint.get('best_val_loss', 'N/A')}")
print("=" * 60)