ishanjmukherjee committed on
Commit
c27420e
·
1 Parent(s): 4b35203

Show subtractions of HF and checkpoint key sets in inspection script

Browse files
Files changed (1) hide show
  1. inspect-keys.py +30 -30
inspect-keys.py CHANGED
# Inspect and compare parameter key sets between a Hugging Face model and a
# local safetensors checkpoint. Reconstructed from the rendered diff: the
# "+" side of commit c27420e ("Show subtractions of HF and checkpoint key
# sets") is taken as the current state of the script.
#
# NOTE(review): `sys` and `safetensors` are referenced below but their
# imports are not visible in this chunk (the hunk header shows only
# `import torch, warnings, json, pathlib` as enclosing context) —
# presumably they are imported earlier in the file; verify.
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.auto.modeling_auto import AutoModelForCausalLM

root = pathlib.Path(".")
print("Loading tokenizer…")
tok = AutoTokenizer.from_pretrained(root, trust_remote_code=True)

print("Loading model… (this takes ~30 s on first run)")
model = AutoModelForCausalLM.from_pretrained(
    root,
    torch_dtype="auto",   # uses bf16/fp16 if your GPU supports it
    device_map="auto",    # spreads across multiple GPUs if present
    trust_remote_code=True)
hf_keys = set(model.state_dict().keys())

# Print HF model keys
print("\n--- HF Model Keys ---")
for k in sorted(hf_keys):
    print(k)
print("\n--- End HF Model Keys ---")

ROOT = pathlib.Path(".")
CKPT_PATH = ROOT / "model.safetensors"

ckpt_keys = set()
if CKPT_PATH.exists():
    try:
        # Print checkpoint keys
        print("\nLoading checkpoint keys...")
        ckpt = safetensors.torch.load_file(CKPT_PATH, device="cpu")
        ckpt_keys = set(ckpt.keys())
        print("\n--- Checkpoint Keys ---")
        for k in sorted(ckpt_keys):
            print(k)
        print("\n--- End Checkpoint Keys ---")

        # The point of this commit: show both set differences so missing
        # and extra parameters are immediately visible.
        print("\nKeys in HF model but not in checkpoint:")
        for k in sorted(hf_keys - ckpt_keys):
            print(k)

        print("\nKeys in checkpoint but not in HF model:")
        for k in sorted(ckpt_keys - hf_keys):
            print(k)
    except Exception as e:
        print(f"\nError loading checkpoint {CKPT_PATH}: {e}", file=sys.stderr)