ishanjmukherjee committed on
Commit
5ce9741
·
1 Parent(s): 27b8600

Update inspect keys script

Browse files
Files changed (1) hide show
  1. inspect-keys.py +32 -17
inspect-keys.py CHANGED
@@ -70,20 +70,20 @@ import torch, warnings, json, pathlib
70
  from transformers.models.auto.tokenization_auto import AutoTokenizer
71
  from transformers.models.auto.modeling_auto import AutoModelForCausalLM
72
 
73
- root = pathlib.Path(".")
74
- print("Loading tokenizer…")
75
- tok = AutoTokenizer.from_pretrained(root, trust_remote_code=True)
76
-
77
- print("Loading model… (this takes ~30 s on first run)")
78
- model = AutoModelForCausalLM.from_pretrained(
79
- root,
80
- torch_dtype="auto", # uses bf16/fp16 if your GPU supports it
81
- device_map="auto", # spreads across multiple GPUs if present
82
- trust_remote_code=True)
83
- hf_keys = set(model.state_dict().keys())
84
- print("\n--- HF Model Keys ---")
85
- for k in sorted(list(hf_keys)):
86
- print(k)
87
 
88
  ROOT = pathlib.Path(".")
89
  CKPT_PATH = ROOT / "model.safetensors"
@@ -92,8 +92,23 @@ ckpt_keys = set()
92
  if CKPT_PATH.exists():
93
  try:
94
  print("\nLoading checkpoint keys...")
95
- ckpt_keys = set(safetensors.torch.load_file(CKPT_PATH, device="cpu").keys())
96
- print("\n--- Checkpoint Keys ---")
97
- for k in sorted(list(ckpt_keys)): print(k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  except Exception as e:
99
  print(f"\nError loading checkpoint {CKPT_PATH}: {e}", file=sys.stderr)
 
70
  from transformers.models.auto.tokenization_auto import AutoTokenizer
71
  from transformers.models.auto.modeling_auto import AutoModelForCausalLM
72
 
73
+ # root = pathlib.Path(".")
74
+ # print("Loading tokenizer…")
75
+ # tok = AutoTokenizer.from_pretrained(root, trust_remote_code=True)
76
+
77
+ # print("Loading model… (this takes ~30 s on first run)")
78
+ # model = AutoModelForCausalLM.from_pretrained(
79
+ # root,
80
+ # torch_dtype="auto", # uses bf16/fp16 if your GPU supports it
81
+ # device_map="auto", # spreads across multiple GPUs if present
82
+ # trust_remote_code=True)
83
+ # hf_keys = set(model.state_dict().keys())
84
+ # print("\n--- HF Model Keys ---")
85
+ # for k in sorted(list(hf_keys)):
86
+ # print(k)
87
 
88
  ROOT = pathlib.Path(".")
89
  CKPT_PATH = ROOT / "model.safetensors"
 
92
  if CKPT_PATH.exists():
93
  try:
94
  print("\nLoading checkpoint keys...")
95
+ ckpt = safetensors.torch.load_file(CKPT_PATH, device="cpu")
96
+ ckpt_keys = set(ckpt.keys())
97
+ # print("\n--- Checkpoint Keys ---")
98
+ # for k in sorted(list(ckpt_keys)):
99
+ # print(k)
100
+ # print("\n--- End Checkpoint Keys ---")
101
+
102
+ non_tensors = {}
103
+ for k, v in ckpt.items():
104
+ if not isinstance(v, torch.Tensor):
105
+ non_tensors[k] = type(v)
106
+
107
+ if non_tensors:
108
+ print("\nWARNING: Found non-tensor objects in model.safetensors!")
109
+ for key, obj_type in non_tensors.items():
110
+ print(f" Key: '{key}', Type: {obj_type}")
111
+ else:
112
+ print("\nAll objects in model.safetensors are Tensors.")
113
  except Exception as e:
114
  print(f"\nError loading checkpoint {CKPT_PATH}: {e}", file=sys.stderr)