| """ |
| Convert GLM-4.7-Flash ScatterMoE (SCM) format back to HuggingFace format. |
| |
| Usage: |
| python convert_scm_to_hf.py <input_scm_model_path> <output_hf_model_path> |
| """ |
import glob
import json
import os
import re
import shutil
import sys

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Make the sibling configuration module importable regardless of the CWD.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from configuration_glm_scm import Glm4MoeLiteSCMConfig
|
|
if len(sys.argv) != 3:
    sys.exit("Usage: python convert_scm_to_hf.py <input_scm_model_path> <output_hf_model_path>")
input_model = sys.argv[1]
output_model_path = sys.argv[2]

cfg_scm = Glm4MoeLiteSCMConfig.from_pretrained(input_model)
num_experts = cfg_scm.n_routed_experts
|
|
# Discover the input checkpoint: sharded "model-*-of-*" files or a single file.
new_state_dict = {}
pattern = os.path.join(input_model, "model-*-of-*.safetensors")
files = sorted(glob.glob(pattern))
if not files:
    pattern = os.path.join(input_model, "model.safetensors")
    files = sorted(glob.glob(pattern))
if not files:
    raise FileNotFoundError(f"No safetensors files found in {input_model}")
|
|
# Load every tensor from every shard onto CPU.
tensors = {}
for file_path in files:
    print(f"Loading {file_path}")
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
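# Layout assumptions (inferred from this script's own unbind/chunk logic below,
# not from an official SCM spec): for each MoE layer,
#   moe_mlp.experts         is stacked [num_experts, 2 * moe_intermediate_size, hidden_size],
#                           with the up_proj rows first, then the gate_proj rows;
#   moe_mlp.output_experts  is stacked [num_experts, hidden_size, moe_intermediate_size],
#                           holding the per-expert down_proj weights.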
|
|
# Remap stacked SCM expert tensors to per-expert HF keys; copy everything else through.
for key, tensor in tensors.items():
    if "moe_mlp" not in key:
        new_state_dict[key] = tensor
    elif "moe_mlp.output_experts" in key:
        # Stacked down_proj weights: split along the leading expert dimension.
        layer_num = int(re.search(r"layers\.(\d+)", key).group(1))
        for i, expert_tensor in enumerate(torch.unbind(tensor)):
            new_state_dict[
                f"model.layers.{layer_num}.mlp.experts.{i}.down_proj.weight"
            ] = expert_tensor.contiguous()
    elif "moe_mlp.experts" in key:
        # Fused up/gate weights: split per expert, then halve along dim 0.
        layer_num = int(re.search(r"layers\.(\d+)", key).group(1))
        for i, expert_tensor in enumerate(torch.unbind(tensor)):
            (
                new_state_dict[f"model.layers.{layer_num}.mlp.experts.{i}.up_proj.weight"],
                new_state_dict[f"model.layers.{layer_num}.mlp.experts.{i}.gate_proj.weight"],
            ) = torch.chunk(expert_tensor, 2, dim=0)
    else:
        # Any other moe_mlp.* key would otherwise be dropped silently.
        print(f"Warning: dropping unhandled SCM key {key}")
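# Optional sanity check: each stacked SCM tensor should carry one slice per
# routed expert. This is a sketch relying on the layout assumptions above,
# not an invariant guaranteed by any spec.
for key, tensor in tensors.items():
    if "moe_mlp.experts" in key or "moe_mlp.output_experts" in key:
        assert tensor.shape[0] == num_experts, f"{key}: {tuple(tensor.shape)}"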
|
|
os.makedirs(output_model_path, exist_ok=True)
save_file(new_state_dict, os.path.join(output_model_path, "model.safetensors"))
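# Note: this writes one unsharded file, which assumes the converted state dict
# fits in memory and is acceptable as a single shard on disk. For very large
# checkpoints you would instead split the dict across shards and emit a
# model.safetensors.index.json alongside them, which this script does not do.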
|
|
# Write a HuggingFace-style config.json for the converted model.
config_dict = {
    "architectures": ["Glm4MoeLiteForCausalLM"],
    "model_type": "glm4_moe_lite",
    "vocab_size": cfg_scm.vocab_size,
    "hidden_size": cfg_scm.hidden_size,
    "intermediate_size": cfg_scm.intermediate_size,
    "moe_intermediate_size": cfg_scm.moe_intermediate_size,
    "num_hidden_layers": cfg_scm.num_hidden_layers,
    "num_attention_heads": cfg_scm.num_attention_heads,
    "num_key_value_heads": cfg_scm.num_key_value_heads,
    "n_shared_experts": cfg_scm.n_shared_experts,
    "n_routed_experts": cfg_scm.n_routed_experts,
    "routed_scaling_factor": cfg_scm.routed_scaling_factor,
    "kv_lora_rank": cfg_scm.kv_lora_rank,
    "q_lora_rank": cfg_scm.q_lora_rank,
    "qk_rope_head_dim": cfg_scm.qk_rope_head_dim,
    "v_head_dim": cfg_scm.v_head_dim,
    "qk_nope_head_dim": cfg_scm.qk_nope_head_dim,
    "n_group": cfg_scm.n_group,
    "topk_group": cfg_scm.topk_group,
    "num_experts_per_tok": cfg_scm.num_experts_per_tok,
    "norm_topk_prob": cfg_scm.norm_topk_prob,
    "topk_method": cfg_scm.topk_method,
    "first_k_dense_replace": cfg_scm.first_k_dense_replace,
    "hidden_act": cfg_scm.hidden_act,
    "max_position_embeddings": cfg_scm.max_position_embeddings,
    "rms_norm_eps": cfg_scm.rms_norm_eps,
    "rope_theta": cfg_scm.rope_theta,
    "rope_scaling": cfg_scm.rope_scaling,
    "attention_bias": cfg_scm.attention_bias,
    "attention_dropout": cfg_scm.attention_dropout,
    "tie_word_embeddings": cfg_scm.tie_word_embeddings,
    "pad_token_id": cfg_scm.pad_token_id,
    "bos_token_id": cfg_scm.bos_token_id,
    "eos_token_id": cfg_scm.eos_token_id,
    "dtype": "bfloat16",
}
with open(os.path.join(output_model_path, "config.json"), "w") as f:
    json.dump(config_dict, f, indent=2)
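# Optional smoke test (commented out). Loading assumes the Glm4MoeLite modeling
# code is available to transformers for this model_type; the exact kwargs may
# differ in your setup.
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained(output_model_path, trust_remote_code=True)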
|
|
# Copy tokenizer assets over from the input model directory.
for fname in os.listdir(input_model):
    if fname.startswith("tokenizer") or fname in [
        "special_tokens_map.json",
        "chat_template.jinja",
    ]:
        src = os.path.join(input_model, fname)
        if os.path.isfile(src):
            shutil.copy(src, os.path.join(output_model_path, fname))
|
|
| print(f"Model saved to {output_model_path}") |
|
|