#!/usr/bin/env python3
"""Inspect a fine-tuned checkpoint and list all parameters belonging to its small experts.

Loads the model config to read ``num_small_experts``, loads the model weights in
bfloat16, then walks ``named_parameters()`` and prints every parameter whose name
contains ``small_experts.``, along with a matched/total summary.
"""

import argparse
import os
import torch
from transformers import AutoConfig, AutoModelForCausalLM


def main():
    """Parse CLI args, load the checkpoint, and report small-expert parameters."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the fine-tuned checkpoint directory (e.g., ./checkpoints/checkpoint-16000)",
    )
    # NOTE(review): parsed but never used below — presumably intended for
    # sys.path injection or trust_remote_code setups; confirm before removing.
    parser.add_argument(
        "--custom_model_path",
        type=str,
        required=False,
        help="(Optional) Path to the model implementation source if needed",
    )
    args = parser.parse_args()

    print(f"Loading config from: {args.model_path}")
    config = AutoConfig.from_pretrained(args.model_path)

    # Guard clause: fail fast if this checkpoint is not a small-expert model.
    if not hasattr(config, "num_small_experts"):
        raise ValueError("The model config does not contain 'num_small_experts'.")
    num_small_experts = config.num_small_experts
    print(f"Number of small experts: {num_small_experts}")

    print("Loading model...")
    # bfloat16 halves memory vs fp32; we only inspect shapes, not run inference.
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    model.eval()

    print("Inspecting small expert weights...")
    total_params = 0   # number of parameter tensors (not element counts)
    matched_params = 0
    for name, param in model.named_parameters():
        total_params += 1
        # Plain substring match on the qualified parameter name
        # (was a no-placeholder f-string in the original).
        if "small_experts." in name:
            matched_params += 1
            print(f"[Matched] {name} - shape: {tuple(param.shape)}")

    print(f"\nMatched {matched_params}/{total_params} parameters containing 'small_experts.'")
    print("Done.")


if __name__ == "__main__":
    main()