#!/usr/bin/env python3
import argparse
import os
import torch
from transformers import AutoConfig, AutoModelForCausalLM
def main():
    """Inspect a fine-tuned checkpoint and list its small-expert parameters.

    Parses CLI arguments, loads the model config and weights from
    ``--model_path``, then prints every parameter whose name contains
    ``small_experts.`` along with its shape and a match-count summary.

    Raises:
        ValueError: if the loaded config has no ``num_small_experts`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the fine-tuned checkpoint directory (e.g., ./checkpoints/checkpoint-16000)",
    )
    # NOTE(review): --custom_model_path is accepted but never read below —
    # either wire it up (e.g. trust_remote_code / sys.path insertion) or drop it.
    parser.add_argument(
        "--custom_model_path",
        type=str,
        required=False,
        help="(Optional) Path to the model implementation source if needed",
    )
    args = parser.parse_args()

    print(f"Loading config from: {args.model_path}")
    config = AutoConfig.from_pretrained(args.model_path)
    # Fail fast with a clear message if this is not a small-expert model.
    num_small_experts = getattr(config, "num_small_experts", None)
    if num_small_experts is None:
        raise ValueError("The model config does not contain 'num_small_experts'.")
    print(f"Number of small experts: {num_small_experts}")

    print("Loading model...")
    # bf16 halves memory vs fp32; weights are only inspected, never trained.
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    model.eval()

    print("Inspecting small expert weights...")
    # Counts are per named tensor, not per scalar parameter.
    total_tensors = 0
    matched_tensors = 0
    for name, param in model.named_parameters():
        total_tensors += 1
        # Plain substring match (was a needless f-string with no placeholder).
        if "small_experts." in name:
            matched_tensors += 1
            print(f"[Matched] {name} - shape: {tuple(param.shape)}")
    print(f"\nMatched {matched_tensors}/{total_tensors} parameters containing 'small_experts.'")
    print("Done.")


if __name__ == "__main__":
    main()
|