"""Check model parameters from a checkpoint.

Usage:
    python check.py <config_file> <checkpoint_file>

Example:
    python check.py config/train_reflow_web.py out-web/ckpt.pt

The config file is executed as Python and must define ``model_config``. The
model definition is then executed from ``models/<model_config>.py`` (a path
relative to the current working directory) and must define ``GPTConfig`` and
``GPT``. Both files are executed and the checkpoint is unpickled by
``torch.load``, so only run this script on files you trust.
"""
import sys
import os
import torch

USAGE = "Usage: python check.py <config_file> <checkpoint_file>"
EXAMPLE = "Example: python check.py config/train_reflow_web.py out-web/ckpt.pt"

# Key prefix found on state_dict entries saved from a PyTorch 2.0+ compiled model.
UNWANTED_PREFIX = '_orig_mod.'


def exit_with_error(message):
    """Print ``ERROR: <message>`` and exit the process with status 1."""
    print(f"ERROR: {message}")
    sys.exit(1)


def parse_args(argv):
    """Validate the command line and return ``(config_file, checkpoint_file)``.

    *argv* is a ``sys.argv``-style list (program name first). On a wrong
    argument count, prints usage and exits with status 1. Also exits with
    status 1 if either named file does not exist.
    """
    if len(argv) != 3:
        print("ERROR: Invalid arguments!")
        print(USAGE)
        print(EXAMPLE)
        sys.exit(1)
    config_file, checkpoint_file = argv[1], argv[2]
    if not os.path.exists(config_file):
        exit_with_error(f"Config file not found: {config_file}")
    if not os.path.exists(checkpoint_file):
        exit_with_error(f"Checkpoint file not found: {checkpoint_file}")
    return config_file, checkpoint_file


def load_source(path):
    """Read and compile the Python file at *path*, returning a code object.

    The file is read as bytes so ``compile`` applies Python's own source
    encoding rules (PEP 263 cookie, UTF-8 default) instead of the locale
    encoding. The handle is closed before the code runs. Compiling with the
    real *path* makes tracebacks from the executed code name the right file.

    Raises:
        FileNotFoundError: if *path* does not exist.
        SyntaxError: if the file is not valid Python.
    """
    with open(path, "rb") as f:
        source = f.read()
    return compile(source, path, "exec")


def strip_compile_prefix(state_dict, prefix=UNWANTED_PREFIX):
    """Strip *prefix* from every key of *state_dict* that starts with it.

    The dict is modified in place and also returned for convenience. Keys
    without the prefix are left untouched.
    """
    # Snapshot the keys: the dict is mutated while we walk it.
    for key in list(state_dict):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
    return state_dict


def print_component_params(transformer):
    """Print parameter counts (in millions) for sub-modules of *transformer*.

    Only the attributes actually present (``wte``, ``h``, ``ln_f``) are reported.
    """
    print("\nParameters by component:")
    # Show wte (embedding) - for Reflow includes vocab_to_signals + signal_basis
    if hasattr(transformer, 'wte'):
        wte = transformer.wte
        if hasattr(wte, 'vocab_to_signals'):
            vocab_to_signals_params = wte.vocab_to_signals.weight.numel()
            print(f" transformer.wte.vocab_to_signals: {vocab_to_signals_params/1e6:>10.2f}M")
        if hasattr(wte, 'signal_basis'):
            signal_basis_params = wte.signal_basis.numel()
            print(f" transformer.wte.signal_basis: {signal_basis_params/1e6:>10.2f}M")
        wte_params = sum(p.numel() for p in wte.parameters())
        print(f" transformer.wte (total): {wte_params/1e6:>10.2f}M")
    # Count transformer.h (layers)
    if hasattr(transformer, 'h'):
        h_params = sum(p.numel() for p in transformer.h.parameters())
        print(f" transformer.h (all layers): {h_params/1e6:>10.2f}M")
    # Show ln_f
    if hasattr(transformer, 'ln_f'):
        ln_f_params = sum(p.numel() for p in transformer.ln_f.parameters())
        print(f" transformer.ln_f: {ln_f_params/1e6:>10.2f}M")


def print_model_info(model, model_config, checkpoint_file, model_args, checkpoint):
    """Print the architecture, arguments, parameter counts and training info."""
    rule = "=" * 60
    print("\n" + rule)
    print("MODEL INFORMATION")
    print(rule)
    print(f"\nModel Architecture: {model_config}")
    print(f"Checkpoint: {checkpoint_file}")
    print("\nModel Arguments:")
    for k, v in model_args.items():
        print(f" {k:20s} = {v}")
    print(f"\nTotal Parameters: {model.get_num_params()/1e6:.2f}M")

    if hasattr(model, 'transformer'):
        print_component_params(model.transformer)

    params = list(model.parameters())
    trainable = sum(p.numel() for p in params if p.requires_grad)
    frozen = sum(p.numel() for p in params if not p.requires_grad)
    print(f"\nTotal trainable parameters: {trainable/1e6:.2f}M")
    print(f"Total non-trainable parameters: {frozen/1e6:.2f}M")

    # Training info if available
    if 'iter_num' in checkpoint:
        best_val_loss = checkpoint.get('best_val_loss', 'N/A')
        # A scalar tensor would otherwise print as "tensor(...)".
        if isinstance(best_val_loss, torch.Tensor) and best_val_loss.numel() == 1:
            best_val_loss = best_val_loss.item()
        print("\nTraining Info:")
        print(f" iter_num: {checkpoint['iter_num']}")
        print(f" best_val_loss: {best_val_loss}")
    print(rule)


def main(argv=None):
    """Entry point: load config, model code and checkpoint, then report."""
    argv = sys.argv if argv is None else argv
    config_file, checkpoint_file = parse_args(argv)

    # Config and model code run in this module's globals, exactly as a
    # top-level exec() would, so model code sees any names the config defines.
    namespace = globals()

    print(f"Loading config from: {config_file}")
    exec(load_source(config_file), namespace)

    model_config = namespace.get('model_config')
    if not model_config:
        exit_with_error("'model_config' is required in config file")

    model_file = f"models/{model_config}.py"
    # Only the read is guarded: a FileNotFoundError raised *by* the model code
    # must not be reported as a missing model file.
    try:
        model_code = load_source(model_file)
    except FileNotFoundError:
        exit_with_error(f"Model file not found: {model_file}")
    exec(model_code, namespace)

    GPTConfig = namespace['GPTConfig']
    GPT = namespace['GPT']

    print(f"Loading checkpoint from: {checkpoint_file}")
    # NOTE(review): torch.load unpickles; newer PyTorch releases default to
    # weights_only=True, which may reject some checkpoint contents — confirm.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    model_args = checkpoint['model_args']

    model = GPT(GPTConfig(**model_args))
    model.load_state_dict(strip_compile_prefix(checkpoint['model']))

    print_model_info(model, model_config, checkpoint_file, model_args, checkpoint)


if __name__ == "__main__":
    main()