|
|
|
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import torch |
|
|
from transformers import AutoConfig, AutoModelForCausalLM |
|
|
|
|
|
def main():
    """Inspect the small-expert parameters of a fine-tuned checkpoint.

    Parses ``--model_path`` (required checkpoint directory) and
    ``--custom_model_path`` (accepted but not otherwise used in this script),
    verifies the model config exposes ``num_small_experts``, loads the model
    in bfloat16, and prints the name/shape of every parameter whose name
    contains ``small_experts.`` plus a match-count summary.

    Raises:
        ValueError: if the loaded config has no ``num_small_experts`` attribute.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the fine-tuned checkpoint directory (e.g., ./checkpoints/checkpoint-16000)",
    )
    parser.add_argument(
        "--custom_model_path",
        type=str,
        required=False,
        help="(Optional) Path to the model implementation source if needed",
    )
    args = parser.parse_args()

    print(f"Loading config from: {args.model_path}")
    config = AutoConfig.from_pretrained(args.model_path)

    # Guard clause: fail fast with a clear message if the checkpoint was not
    # produced by the expected small-expert architecture.
    if not hasattr(config, "num_small_experts"):
        raise ValueError("The model config does not contain 'num_small_experts'.")
    num_small_experts = config.num_small_experts

    print(f"Number of small experts: {num_small_experts}")

    print("Loading model...")
    # bfloat16 halves memory versus fp32; we only inspect shapes, not values.
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    model.eval()

    print("Inspecting small expert weights...")
    total_params = 0
    matched_params = 0
    for name, param in model.named_parameters():
        total_params += 1
        # Plain substring test — the original used an f-string with no
        # placeholders here (f"small_experts."), which is a no-op (ruff F541).
        if "small_experts." in name:
            matched_params += 1
            print(f"[Matched] {name} - shape: {tuple(param.shape)}")
    print(f"\nMatched {matched_params}/{total_params} parameters containing 'small_experts.'")

    print("Done.")
|
|
|
|
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":


    main()
|
|
|