# LoRE/scripts/inspect_experts.py
# (Hosted-page residue preserved as a comment: uploaded by Charlie81,
#  commit 97a7f0a, message "inspect".)
#!/usr/bin/env python3
import argparse
import os
import torch
from transformers import AutoConfig, AutoModelForCausalLM
def main():
    """Inspect a fine-tuned MoE checkpoint and list its small-expert parameters.

    Loads the config and model from ``--model_path``, requires the config to
    define ``num_small_experts``, then prints every parameter whose name
    contains ``'small_experts.'`` together with its shape, plus a summary count.

    Raises:
        ValueError: If the checkpoint config lacks ``num_small_experts``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the fine-tuned checkpoint directory (e.g., ./checkpoints/checkpoint-16000)",
    )
    parser.add_argument(
        "--custom_model_path",
        type=str,
        required=False,
        help="(Optional) Path to the model implementation source if needed",
    )
    args = parser.parse_args()
    # NOTE(review): --custom_model_path is accepted but never used below —
    # either wire it up (e.g. trust_remote_code / sys.path insertion) or drop it.

    print(f"Loading config from: {args.model_path}")
    config = AutoConfig.from_pretrained(args.model_path)
    if hasattr(config, "num_small_experts"):
        num_small_experts = config.num_small_experts
    else:
        raise ValueError("The model config does not contain 'num_small_experts'.")
    print(f"Number of small experts: {num_small_experts}")

    print("Loading model...")
    # bfloat16 halves memory vs fp32; we only read names/shapes, no compute.
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    model.eval()

    print("Inspecting small expert weights...")
    total_params = 0    # number of parameter *tensors*, not element counts
    matched_params = 0
    for name, param in model.named_parameters():
        total_params += 1
        # Fix: plain string literal — original used an f-string with no placeholders.
        if "small_experts." in name:
            matched_params += 1
            print(f"[Matched] {name} - shape: {tuple(param.shape)}")
    print(f"\nMatched {matched_params}/{total_params} parameters containing 'small_experts.'")
    print("Done.")
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()