| """
|
| Check model parameters from a checkpoint
|
|
|
| Usage:
|
| python check.py <config_file> <checkpoint_file>
|
|
|
| Example:
|
| python check.py config/train_reflow_web.py out-web/ckpt.pt
|
| """
|
| import sys
|
| import os
|
| import torch
|
|
|
|
|
|
|
|
|
# --- Command-line argument validation ----------------------------------------
# Expect exactly two positional arguments: the config file and the checkpoint.
if len(sys.argv) != 3:
    print("ERROR: Invalid arguments!")
    print("Usage: python check.py <config_file> <checkpoint_file>")
    print("Example: python check.py config/train_reflow_web.py out-web/ckpt.pt")
    sys.exit(1)

config_file, checkpoint_file = sys.argv[1], sys.argv[2]

# Fail fast with a clear message if either input path does not exist.
for _label, _path in (("Config", config_file), ("Checkpoint", checkpoint_file)):
    if not os.path.exists(_path):
        print(f"ERROR: {_label} file not found: {_path}")
        sys.exit(1)
|
|
|
|
|
# --- Load training config and model definition --------------------------------
# NOTE(security): exec() of arbitrary file contents is only acceptable because
# config and model files are trusted, local project files. Never point this
# script at untrusted input.
print(f"Loading config from: {config_file}")
with open(config_file) as _f:  # 'with' ensures the handle is closed (was leaked)
    exec(_f.read())

# The config file is expected to define `model_config`, naming the model
# module under models/ that declares the architecture.
model_config = globals().get('model_config')
if not model_config:
    print("ERROR: 'model_config' is required in config file")
    sys.exit(1)

model_file = f"models/{model_config}.py"
try:
    with open(model_file) as _f:  # 'with' ensures the handle is closed (was leaked)
        exec(_f.read())
except FileNotFoundError:
    print(f"ERROR: Model file not found: {model_file}")
    sys.exit(1)

# The model file must define the GPTConfig and GPT classes; report a clear
# error (consistent with the rest of this script) instead of a raw KeyError.
try:
    GPTConfig = globals()['GPTConfig']
    GPT = globals()['GPT']
except KeyError as e:
    print(f"ERROR: {model_file} must define {e.args[0]}")
    sys.exit(1)
|
|
|
|
|
# --- Load checkpoint and reconstruct the model --------------------------------
print(f"Loading checkpoint from: {checkpoint_file}")
# NOTE(security): torch.load unpickles arbitrary objects — only load trusted
# checkpoint files.
checkpoint = torch.load(checkpoint_file, map_location='cpu')

# Validate the expected checkpoint layout up front so a malformed file yields
# a clear error (consistent with the rest of this script) instead of a raw
# KeyError traceback.
for _key in ('model_args', 'model'):
    if _key not in checkpoint:
        print(f"ERROR: checkpoint is missing required key: '{_key}'")
        sys.exit(1)

model_args = checkpoint['model_args']
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
state_dict = checkpoint['model']

# torch.compile() wraps the model and prefixes every parameter name with
# '_orig_mod.'; strip that prefix so keys match the freshly-built model.
unwanted_prefix = '_orig_mod.'
for k in list(state_dict.keys()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

model.load_state_dict(state_dict)
|
|
|
|
|
# --- Report model information --------------------------------------------------
_RULE = "=" * 60

print("\n" + _RULE)
print("MODEL INFORMATION")
print(_RULE)

print(f"\nModel Architecture: {model_config}")
print(f"Checkpoint: {checkpoint_file}")

# Dump the raw constructor arguments stored in the checkpoint.
print("\nModel Arguments:")
for arg_name, arg_value in model_args.items():
    print(f"  {arg_name:20s} = {arg_value}")

print(f"\nTotal Parameters: {model.get_num_params()/1e6:.2f}M")

# Per-component breakdown — each sub-module is optional, so every access is
# guarded with hasattr.
if hasattr(model, 'transformer'):
    print("\nParameters by component:")

    if hasattr(model.transformer, 'wte'):
        wte = model.transformer.wte
        if hasattr(wte, 'vocab_to_signals'):
            v2s_count = wte.vocab_to_signals.weight.numel()
            print(f"  transformer.wte.vocab_to_signals: {v2s_count/1e6:>10.2f}M")
        if hasattr(wte, 'signal_basis'):
            basis_count = wte.signal_basis.numel()
            print(f"  transformer.wte.signal_basis: {basis_count/1e6:>10.2f}M")
        wte_count = sum(p.numel() for p in wte.parameters())
        print(f"  transformer.wte (total): {wte_count/1e6:>10.2f}M")

    if hasattr(model.transformer, 'h'):
        layers_count = sum(p.numel() for p in model.transformer.h.parameters())
        print(f"  transformer.h (all layers): {layers_count/1e6:>10.2f}M")

    if hasattr(model.transformer, 'ln_f'):
        ln_f_count = sum(p.numel() for p in model.transformer.ln_f.parameters())
        print(f"  transformer.ln_f: {ln_f_count/1e6:>10.2f}M")

# Split totals by requires_grad to distinguish trainable from frozen weights.
trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen_count = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"\nTotal trainable parameters: {trainable_count/1e6:.2f}M")
print(f"Total non-trainable parameters: {frozen_count/1e6:.2f}M")

# Training metadata is only present in checkpoints saved mid-training.
if 'iter_num' in checkpoint:
    print("\nTraining Info:")
    print(f"  iter_num: {checkpoint['iter_num']}")
    print(f"  best_val_loss: {checkpoint.get('best_val_loss', 'N/A')}")

print(_RULE)
|
|
|